From f353eb056a335760ff2390e1ca6f42c7b031d260 Mon Sep 17 00:00:00 2001 From: unraid Date: Sun, 26 Apr 2026 21:44:42 +0800 Subject: [PATCH 1/3] fix: bound agent communication memory growth UDS messaging now uses private local capabilities instead of exposing auth tokens through SDK metadata, environment variables, session registry, peer listing, or tool output. The receive path bounds NDJSON frames, response buffers, active clients, and pending inbox bytes, and strips auth metadata before messages enter the prompt queue. Teammate mailboxes now validate file and message sizes, fail closed on corrupt mutation inputs, compact by count and retained bytes, and use stable message identity for in-process acknowledgements. Agent summaries now fork only a bounded recent context using lazy size estimation and content fingerprints instead of retaining or serializing unbounded histories. Constraint: PR #361 was already merged; this branch is based on upstream/main@c2ac9a74. Rejected: Default-disabling COORDINATOR_MODE/TEAMMEM only | explicit feature enablement still hit unbounded paths. Rejected: Persisting UDS auth in SDK/env/session registry | bridge/remote metadata can leak local capability secrets. Rejected: Inline uds #token addresses | observable/tool/classifier paths can reflect raw addresses outside the UDS request frame. Rejected: Positional mailbox marking after compaction | compaction can shift indices across the lock boundary. Confidence: high Scope-risk: moderate Directive: Do not expose UDS capability tokens through SDK messages, environment variables, session registry, peer-list output, or SendMessage result/classifier surfaces. Directive: Do not reintroduce positional mailbox acknowledgements unless compaction is removed or read+mark is atomic under one lock. Tested: bun test src/utils/__tests__/ndjsonFramer.test.ts src/utils/__tests__/udsMessaging.test.ts packages/builtin-tools/src/tools/SendMessageTool/__tests__/udsRecipientSanitization.test.ts Tested: bunx tsc --noEmit --pretty false Tested: bun run lint Tested: bunx biome lint modified src/package files Tested: bun run test:all (3704 pass, 0 fail, 6734 expects) Tested: bun audit (No vulnerabilities found) Tested: bun run build Tested: bun run build:vite Tested: git diff --check Not-tested: End-to-end external UDS client driving a full production headless model turn. --- .../src/tools/ListPeersTool/ListPeersTool.ts | 24 +- .../tools/SendMessageTool/SendMessageTool.ts | 81 ++- .../udsRecipientSanitization.test.ts | 41 ++ src/cli/print.ts | 28 +- src/commands/peers/peers.ts | 15 +- .../__tests__/summaryContext.test.ts | 132 +++++ src/services/AgentSummary/agentSummary.ts | 34 +- src/services/AgentSummary/summaryContext.ts | 175 ++++++ src/utils/__tests__/ndjsonFramer.test.ts | 91 ++++ src/utils/__tests__/teammateMailbox.test.ts | 310 +++++++++++ src/utils/__tests__/udsMessaging.test.ts | 305 +++++++++++ src/utils/messages/systemInit.ts | 4 +- src/utils/ndjsonFramer.ts | 38 ++ src/utils/swarm/inProcessRunner.ts | 28 +- src/utils/teammateMailbox.ts | 303 ++++++++++- src/utils/udsClient.ts | 114 +++- src/utils/udsMessaging.ts | 499 +++++++++++++++--- 17 files changed, 2086 insertions(+), 136 deletions(-) create mode 100644 packages/builtin-tools/src/tools/SendMessageTool/__tests__/udsRecipientSanitization.test.ts create mode 100644 src/services/AgentSummary/__tests__/summaryContext.test.ts create mode 100644 src/services/AgentSummary/summaryContext.ts create mode 100644 src/utils/__tests__/ndjsonFramer.test.ts create mode 100644 src/utils/__tests__/teammateMailbox.test.ts create mode 100644 src/utils/__tests__/udsMessaging.test.ts diff --git a/packages/builtin-tools/src/tools/ListPeersTool/ListPeersTool.ts b/packages/builtin-tools/src/tools/ListPeersTool/ListPeersTool.ts index e520243a53..48219edc72 100644 --- a/packages/builtin-tools/src/tools/ListPeersTool/ListPeersTool.ts +++ b/packages/builtin-tools/src/tools/ListPeersTool/ListPeersTool.ts @@ -85,21 +85,35 @@ Use this tool to discover messaging targets before sending cross-session message // and optionally includes Remote Control bridge peers. const peers: PeerInfo[] = [] - // Discovery is handled by the UDS messaging subsystem initialized in setup.ts. - // Return discovered peers from the app state. - const appState = context.getAppState() - const messagingSocketPath = (appState as Record).messagingSocketPath as string | undefined + /* eslint-disable @typescript-eslint/no-require-imports */ + const udsMessaging = + require('src/utils/udsMessaging.js') as typeof import('src/utils/udsMessaging.js') + const udsClient = + require('src/utils/udsClient.js') as typeof import('src/utils/udsClient.js') + /* eslint-enable @typescript-eslint/no-require-imports */ + + const messagingSocketPath = udsMessaging.getUdsMessagingSocketPath() if (messagingSocketPath) { // Self entry for reference if (_input.include_self) { peers.push({ - address: `uds:${messagingSocketPath}`, + address: udsMessaging.formatUdsAddress(messagingSocketPath), name: 'self', pid: process.pid, }) } } + for (const peer of await udsClient.listPeers()) { + if (!peer.messagingSocketPath) continue + peers.push({ + address: udsMessaging.formatUdsAddress(peer.messagingSocketPath), + name: peer.name ?? peer.kind, + cwd: peer.cwd, + pid: peer.pid, + }) + } + return { data: { peers }, } diff --git a/packages/builtin-tools/src/tools/SendMessageTool/SendMessageTool.ts b/packages/builtin-tools/src/tools/SendMessageTool/SendMessageTool.ts index 4e97370518..3548544fca 100644 --- a/packages/builtin-tools/src/tools/SendMessageTool/SendMessageTool.ts +++ b/packages/builtin-tools/src/tools/SendMessageTool/SendMessageTool.ts @@ -130,6 +130,43 @@ export type SendMessageToolOutput = | RequestOutput | ResponseOutput +const UDS_INLINE_TOKEN_MARKER = '#token=' +const UDS_INLINE_TOKEN_REJECTED_KEY = '__udsInlineTokenRejected' + +function stripInlineUdsToken(target: string): string { + const markerIndex = target.lastIndexOf(UDS_INLINE_TOKEN_MARKER) + return markerIndex === -1 ? target : target.slice(0, markerIndex) +} + +function hasInlineUdsToken(to: string): boolean { + const addr = parseAddress(to) + return ( + addr.scheme === 'uds' && addr.target.includes(UDS_INLINE_TOKEN_MARKER) + ) +} + +function recipientForDisplay(to: string): string { + const addr = parseAddress(to) + if (addr.scheme !== 'uds') return to + return `uds:${stripInlineUdsToken(addr.target)}` +} + +function markAndRedactInlineUdsToken( + input: { to: string } & Record, +): void { + if (!hasInlineUdsToken(input.to)) return + input.to = recipientForDisplay(input.to) + input[UDS_INLINE_TOKEN_REJECTED_KEY] = true +} + +function wasInlineUdsTokenRejected(input: unknown): boolean { + return ( + typeof input === 'object' && + input !== null && + (input as Record)[UDS_INLINE_TOKEN_REJECTED_KEY] === true + ) +} + function findTeammateColor( appState: { teamContext?: { teammates: { [id: string]: { color?: string } } } @@ -541,15 +578,19 @@ export const SendMessageTool: Tool = }, backfillObservableInput(input) { - if ('type' in input) return if (typeof input.to !== 'string') return + markAndRedactInlineUdsToken( + input as { to: string } & Record, + ) + if ('type' in input) return + if (input.to === '*') { input.type = 'broadcast' if (typeof input.message === 'string') input.content = input.message } else if (typeof input.message === 'string') { input.type = 'message' - input.recipient = input.to + input.recipient = recipientForDisplay(input.to) input.content = input.message } else if (typeof input.message === 'object' && input.message !== null) { const msg = input.message as { @@ -560,7 +601,7 @@ export const SendMessageTool: Tool = feedback?: string } input.type = msg.type - input.recipient = input.to + input.recipient = recipientForDisplay(input.to) if (msg.request_id !== undefined) input.request_id = msg.request_id if (msg.approve !== undefined) input.approve = msg.approve const content = msg.reason ?? msg.feedback @@ -569,16 +610,17 @@ export const SendMessageTool: Tool = }, toAutoClassifierInput(input) { + const recipient = recipientForDisplay(input.to) if (typeof input.message === 'string') { - return `to ${input.to}: ${input.message}` + return `to ${recipient}: ${input.message}` } switch (input.message.type) { case 'shutdown_request': - return `shutdown_request to ${input.to}` + return `shutdown_request to ${recipient}` case 'shutdown_response': return `shutdown_response ${input.message.approve ? 'approve' : 'reject'} ${input.message.request_id}` case 'plan_approval_response': - return `plan_approval ${input.message.approve ? 'approve' : 'reject'} to ${input.to}` + return `plan_approval ${input.message.approve ? 'approve' : 'reject'} to ${recipient}` } }, @@ -630,6 +672,19 @@ export const SendMessageTool: Tool = errorCode: 9, } } + if (feature('UDS_INBOX')) { + if ( + addr.scheme === 'uds' && + (hasInlineUdsToken(input.to) || wasInlineUdsTokenRejected(input)) + ) { + return { + result: false, + message: + 'uds addresses must not include inline auth tokens; use the ListPeers address', + errorCode: 9, + } + } + } if (input.to.includes('@')) { return { result: false, @@ -787,6 +842,16 @@ export const SendMessageTool: Tool = } } if (addr.scheme === 'uds') { + const recipient = recipientForDisplay(input.to) + if (hasInlineUdsToken(input.to) || wasInlineUdsTokenRejected(input)) { + return { + data: { + success: false, + message: + 'uds addresses must not include inline auth tokens; use the ListPeers address', + }, + } + } /* eslint-disable @typescript-eslint/no-require-imports */ const { sendToUdsSocket } = require('src/utils/udsClient.js') as typeof import('src/utils/udsClient.js') @@ -797,14 +862,14 @@ export const SendMessageTool: Tool = return { data: { success: true, - message: `”${preview}” → ${input.to}`, + message: `”${preview}” → ${recipient}`, }, } } catch (e) { return { data: { success: false, - message: `Failed to send to ${input.to}: ${errorMessage(e)}`, + message: `Failed to send to ${recipient}: ${errorMessage(e)}`, }, } } diff --git a/packages/builtin-tools/src/tools/SendMessageTool/__tests__/udsRecipientSanitization.test.ts b/packages/builtin-tools/src/tools/SendMessageTool/__tests__/udsRecipientSanitization.test.ts new file mode 100644 index 0000000000..20124b6c38 --- /dev/null +++ b/packages/builtin-tools/src/tools/SendMessageTool/__tests__/udsRecipientSanitization.test.ts @@ -0,0 +1,41 @@ +import { describe, expect, mock, test } from 'bun:test' + +mock.module('bun:bundle', () => ({ + feature: (name: string) => name === 'UDS_INBOX', +})) + +describe('SendMessageTool UDS recipient handling', () => { + test('redacts inline UDS tokens before classifier and observable paths', async () => { + const { SendMessageTool } = await import('../SendMessageTool.js') + const tokenAddress = 'uds:/tmp/peer.sock#token=secret-token' + + const observableInput = { + to: tokenAddress, + message: 'hello', + } as Record + SendMessageTool.backfillObservableInput!(observableInput) + + expect(observableInput.recipient).toBe('uds:/tmp/peer.sock') + expect(JSON.stringify(observableInput)).not.toContain('secret-token') + expect( + SendMessageTool.toAutoClassifierInput({ + to: tokenAddress, + message: 'hello', + }), + ).toBe('to uds:/tmp/peer.sock: hello') + }) + + test('rejects inline UDS tokens during validation', async () => { + const { SendMessageTool } = await import('../SendMessageTool.js') + const result = await SendMessageTool.validateInput!( + { + to: 'uds:/tmp/peer.sock#token=secret-token', + message: 'hello', + }, + {} as never, + ) + + expect(result.result).toBe(false) + expect(JSON.stringify(result)).not.toContain('secret-token') + }) +}) diff --git a/src/cli/print.ts b/src/cli/print.ts index eb2b543f86..7a291fb504 100644 --- a/src/cli/print.ts +++ b/src/cli/print.ts @@ -2763,13 +2763,37 @@ function runHeadlessStreaming( // when a message arrives via the UDS socket in headless mode. if (feature('UDS_INBOX')) { /* eslint-disable @typescript-eslint/no-require-imports */ - const { setOnEnqueue } = require('../utils/udsMessaging.js') + const { drainInbox, setOnEnqueue } = + require('../utils/udsMessaging.js') as typeof import('../utils/udsMessaging.js') /* eslint-enable @typescript-eslint/no-require-imports */ + + const enqueueUdsInboxMessages = (): boolean => { + const entries = drainInbox() + for (const entry of entries) { + const value = + typeof entry.message.data === 'string' + ? entry.message.data + : jsonStringify(entry.message) + enqueue({ + mode: 'prompt', + value, + uuid: randomUUID(), + }) + } + return entries.length > 0 + } + setOnEnqueue(() => { if (!inputClosed) { - void run() + if (enqueueUdsInboxMessages()) { + void run() + } } }) + + if (enqueueUdsInboxMessages()) { + void run() + } } // Cron scheduler: runs scheduled_tasks.json tasks in SDK/-p mode. diff --git a/src/commands/peers/peers.ts b/src/commands/peers/peers.ts index aed37d3279..fcb7e17a74 100644 --- a/src/commands/peers/peers.ts +++ b/src/commands/peers/peers.ts @@ -1,6 +1,9 @@ import type { LocalCommandCall } from '../../types/command.js' import { listPeers, isPeerAlive } from '../../utils/udsClient.js' -import { getUdsMessagingSocketPath } from '../../utils/udsMessaging.js' +import { + formatUdsAddress, + getUdsMessagingSocketPath, +} from '../../utils/udsMessaging.js' export const call: LocalCommandCall = async (_args, _context) => { const mySocket = getUdsMessagingSocketPath() @@ -29,11 +32,11 @@ export const call: LocalCommandCall = async (_args, _context) => { ? ` started: ${formatAge(peer.startedAt)}` : '' - lines.push( - ` [${status}] PID ${peer.pid} (${label})${cwd}${age}`, - ) + lines.push(` [${status}] PID ${peer.pid} (${label})${cwd}${age}`) if (peer.messagingSocketPath) { - lines.push(` socket: ${peer.messagingSocketPath}`) + lines.push( + ` socket: ${formatUdsAddress(peer.messagingSocketPath)}`, + ) } if (peer.sessionId) { lines.push(` session: ${peer.sessionId}`) @@ -43,7 +46,7 @@ export const call: LocalCommandCall = async (_args, _context) => { lines.push('') lines.push( - 'To message a peer: use SendMessage with to="uds:"', + 'To message a peer: use SendMessage with the shown uds: address', ) return { type: 'text', value: lines.join('\n') } diff --git a/src/services/AgentSummary/__tests__/summaryContext.test.ts b/src/services/AgentSummary/__tests__/summaryContext.test.ts new file mode 100644 index 0000000000..3ffa559645 --- /dev/null +++ b/src/services/AgentSummary/__tests__/summaryContext.test.ts @@ -0,0 +1,132 @@ +import { describe, expect, test } from 'bun:test' +import type { Message } from '../../../types/message.js' +import { + getSummaryContextFingerprint, + selectSummaryContextMessages, +} from '../summaryContext.js' + +function makeMessage( + type: 'user' | 'assistant', + uuid: string, + content: string, +): Message { + return { + type, + uuid, + message: { + role: type, + content, + }, + } as unknown as Message +} + +describe('selectSummaryContextMessages', () => { + test('keeps a bounded recent suffix that starts with a user message', () => { + const messages = [ + makeMessage('assistant', 'a0', 'older assistant'), + makeMessage('user', 'u1', 'first prompt'), + makeMessage('assistant', 'a1', 'first response'), + makeMessage('user', 'u2', 'second prompt'), + makeMessage('assistant', 'a2', 'second response'), + ] + + const selected = selectSummaryContextMessages(messages, { + maxMessages: 3, + maxChars: 1_000, + }) + + expect(selected.map(message => String(message.uuid))).toEqual(['u2', 'a2']) + }) + + test('returns no context when the newest message exceeds the byte budget', () => { + const messages = [ + makeMessage('user', 'u1', 'first prompt'), + makeMessage('assistant', 'a1', 'x'.repeat(100)), + ] + + const selected = selectSummaryContextMessages(messages, { + maxMessages: 10, + maxChars: 10, + }) + + expect(selected).toEqual([]) + }) + + test('uses serialized message size for nested content budgets', () => { + const messages = [ + makeMessage('user', 'u1', 'first prompt'), + { + ...makeMessage('assistant', 'a1', 'short'), + nested: { + payload: Array.from({ length: 50 }, (_value, index) => ({ + index, + text: 'x'.repeat(20), + })), + }, + } as unknown as Message, + ] + + const selected = selectSummaryContextMessages(messages, { + maxMessages: 10, + maxChars: 200, + }) + + expect(selected).toEqual([]) + }) + + test('drops leading orphan tool results after bounding', () => { + const messages = [ + makeMessage('assistant', 'a0', 'older assistant'), + { + type: 'user', + uuid: 'u1', + message: { + role: 'user', + content: [ + { type: 'tool_result', tool_use_id: 'tool-1', content: 'ok' }, + ], + }, + } as unknown as Message, + makeMessage('assistant', 'a1', 'after orphan'), + makeMessage('user', 'u2', 'next prompt'), + ] + + const selected = selectSummaryContextMessages(messages, { + maxMessages: 3, + maxChars: 1_000, + }) + + expect(selected.map(message => String(message.uuid))).toEqual(['u2']) + }) +}) + +describe('getSummaryContextFingerprint', () => { + test('changes when the transcript grows', () => { + const messages = [ + makeMessage('user', 'u1', 'first prompt'), + makeMessage('assistant', 'a1', 'first response'), + ] + + const first = getSummaryContextFingerprint(messages) + const second = getSummaryContextFingerprint([ + ...messages, + makeMessage('user', 'u2', 'next prompt'), + ]) + expect(first?.startsWith('2:a1:')).toBe(true) + expect(second?.startsWith('3:u2:')).toBe(true) + expect(first).not.toBe(second) + }) + + test('changes when message content changes under the same uuid', () => { + const first = getSummaryContextFingerprint([ + makeMessage('user', 'u1', 'first prompt'), + makeMessage('assistant', 'a1', 'first response'), + ]) + const second = getSummaryContextFingerprint([ + makeMessage('user', 'u1', 'first prompt'), + makeMessage('assistant', 'a1', 'updated response'), + ]) + + expect(first).not.toBe(second) + }) +}) diff --git a/src/services/AgentSummary/agentSummary.ts b/src/services/AgentSummary/agentSummary.ts index 50146b3c79..2232e839d7 100644 --- a/src/services/AgentSummary/agentSummary.ts +++ b/src/services/AgentSummary/agentSummary.ts @@ -23,6 +23,10 @@ import { import { logError } from '../../utils/log.js' import { createUserMessage } from '../../utils/messages.js' import { getAgentTranscript } from '../../utils/sessionStorage.js' +import { + getSummaryContextFingerprint, + selectSummaryContextMessages, +} from './summaryContext.js' const SUMMARY_INTERVAL_MS = 30_000 @@ -58,6 +62,7 @@ export function startAgentSummarization( let timeoutId: ReturnType | null = null let stopped = false let previousSummary: string | null = null + let lastHandledTranscriptFingerprint: string | null = null async function runSummary(): Promise { if (stopped) return @@ -82,15 +87,35 @@ export function startAgentSummarization( // Filter to clean message state const cleanMessages = filterIncompleteToolCalls(transcript.messages) + const summaryContext = filterIncompleteToolCalls( + selectSummaryContextMessages(cleanMessages), + ) + const transcriptFingerprint = getSummaryContextFingerprint(summaryContext) + if ( + transcriptFingerprint && + transcriptFingerprint === lastHandledTranscriptFingerprint + ) { + logForDebugging( + `[AgentSummary] Skipping summary for ${taskId}: transcript unchanged`, + ) + return + } + + if (summaryContext.length < 3) { + logForDebugging( + `[AgentSummary] Skipping summary for ${taskId}: no bounded context available`, + ) + return + } // Build fork params with current messages const forkParams: CacheSafeParams = { ...baseParams, - forkContextMessages: cleanMessages, + forkContextMessages: summaryContext, } logForDebugging( - `[AgentSummary] Forking for summary, ${cleanMessages.length} messages in context`, + `[AgentSummary] Forking for summary, ${summaryContext.length} messages in context`, ) // Create abort controller for this summary @@ -136,13 +161,16 @@ export function startAgentSummarization( ) continue } - const contentArr = Array.isArray(msg.message!.content) ? msg.message!.content : [] + const contentArr = Array.isArray(msg.message!.content) + ? msg.message!.content + : [] const textBlock = contentArr.find(b => b.type === 'text') if (textBlock?.type === 'text' && textBlock.text.trim()) { const summaryText = textBlock.text.trim() logForDebugging( `[AgentSummary] Summary result for ${taskId}: ${summaryText}`, ) + lastHandledTranscriptFingerprint = transcriptFingerprint previousSummary = summaryText updateAgentSummary(taskId, summaryText, setAppState) break diff --git a/src/services/AgentSummary/summaryContext.ts b/src/services/AgentSummary/summaryContext.ts new file mode 100644 index 0000000000..4d9f6a6ce5 --- /dev/null +++ b/src/services/AgentSummary/summaryContext.ts @@ -0,0 +1,175 @@ +import { createHash } from 'crypto' +import type { Message } from '../../types/message.js' + +export const MAX_SUMMARY_CONTEXT_MESSAGES = 120 +export const MAX_SUMMARY_CONTEXT_CHARS = 200_000 + +function estimateJsonChars( + value: unknown, + limit: number, + seen = new Set(), +): number { + if (value === null) return 4 + switch (typeof value) { + case 'string': + return value.length + 2 + case 'number': + case 'boolean': + return String(value).length + case 'undefined': + case 'function': + case 'symbol': + return 0 + case 'object': { + if (seen.has(value)) return Number.POSITIVE_INFINITY + seen.add(value) + let total = 2 + if (Array.isArray(value)) { + for (let index = 0; index < value.length; index++) { + total += String(index).length + 3 + total += estimateJsonChars(value[index], limit - total, seen) + if (total > limit) return total + } + } else { + const record = value as Record + for (const key in record) { + if (!Object.hasOwn(record, key)) continue + total += key.length + 3 + total += estimateJsonChars(record[key], limit - total, seen) + if (total > limit) return total + } + } + seen.delete(value) + return total + } + } + return 0 +} + +function updateFingerprintHash( + hash: ReturnType, + value: unknown, + limit: { remaining: number }, + seen = new Set(), +): void { + if (limit.remaining <= 0) return + if (value === null || typeof value !== 'object') { + const text = String(value) + hash.update(typeof value) + hash.update(':') + hash.update(text.slice(0, limit.remaining)) + limit.remaining -= text.length + return + } + if (seen.has(value)) { + hash.update('[Circular]') + return + } + seen.add(value) + if (Array.isArray(value)) { + for (let index = 0; index < value.length; index++) { + if (limit.remaining <= 0) break + const key = String(index) + hash.update(key) + limit.remaining -= key.length + updateFingerprintHash(hash, value[index], limit, seen) + } + } else { + const record = value as Record + for (const key in record) { + if (limit.remaining <= 0) break + if (!Object.hasOwn(record, key)) continue + hash.update(key) + limit.remaining -= key.length + updateFingerprintHash(hash, record[key], limit, seen) + } + } + seen.delete(value) +} + +export function estimateMessageChars( + message: Message, + limit = Number.POSITIVE_INFINITY, +): number { + const estimated = estimateJsonChars(message, limit) + if (!Number.isFinite(estimated)) { + return Number.POSITIVE_INFINITY + } + return estimated +} + +function hasToolResultBlock(message: Message): boolean { + if (message.type !== 'user') return false + const content = message.message?.content + return ( + Array.isArray(content) && + content.some(block => { + return Boolean( + block && + typeof block === 'object' && + 'type' in block && + block.type === 'tool_result', + ) + }) + ) +} + +export function getSummaryContextFingerprint( + messages: Message[], +): string | null { + const lastMessage = messages.at(-1) + if (!lastMessage) return null + const hash = createHash('sha256') + updateFingerprintHash(hash, messages, { + remaining: MAX_SUMMARY_CONTEXT_CHARS, + }) + return `${messages.length}:${lastMessage.uuid}:${hash.digest('hex').slice(0, 16)}` +} + +export function selectSummaryContextMessages( + messages: Message[], + limits: { + maxMessages?: number + maxChars?: number + } = {}, +): Message[] { + const maxMessages = limits.maxMessages ?? MAX_SUMMARY_CONTEXT_MESSAGES + const maxChars = limits.maxChars ?? MAX_SUMMARY_CONTEXT_CHARS + if (maxMessages <= 0 || maxChars <= 0) return [] + + const selected: Message[] = [] + let selectedChars = 0 + + for (let i = messages.length - 1; i >= 0; i--) { + const message = messages[i] + if (!message) continue + + const messageChars = estimateMessageChars(message, maxChars - selectedChars) + if (messageChars > maxChars) { + if (selected.length === 0) return [] + break + } + + if ( + selected.length >= maxMessages || + selectedChars + messageChars > maxChars + ) { + break + } + + selected.unshift(message) + selectedChars += messageChars + } + + while (selected.length > 0) { + const first = selected[0] + if (!first) break + if (first.type !== 'user' || hasToolResultBlock(first)) { + selected.shift() + continue + } + break + } + + return selected +} diff --git a/src/utils/__tests__/ndjsonFramer.test.ts b/src/utils/__tests__/ndjsonFramer.test.ts new file mode 100644 index 0000000000..35174162a3 --- /dev/null +++ b/src/utils/__tests__/ndjsonFramer.test.ts @@ -0,0 +1,91 @@ +import { EventEmitter } from 'node:events' +import type { Socket } from 'node:net' +import { describe, expect, test } from 'bun:test' +import { attachNdjsonFramer } from '../ndjsonFramer.js' + +type TestSocket = Socket & { + destroyed: boolean + emitData: (chunk: Buffer) => void +} + +function createTestSocket(): TestSocket { + const emitter = new EventEmitter() as TestSocket + emitter.destroyed = false + emitter.destroy = ((_error?: Error) => { + emitter.destroyed = true + emitter.emit('close') + return emitter + }) as TestSocket['destroy'] + emitter.emitData = (chunk: Buffer) => { + emitter.emit('data', chunk) + } + return emitter +} + +describe('attachNdjsonFramer', () => { + test('accepts a complete frame at the configured byte limit', () => { + const socket = createTestSocket() + const messages: unknown[] = [] + const errors: Error[] = [] + + attachNdjsonFramer( + socket, + msg => messages.push(msg), + text => JSON.parse(text) as unknown, + { + maxFrameBytes: Buffer.byteLength('{"a":1}', 'utf8'), + onFrameError: error => errors.push(error), + }, + ) + + socket.emitData(Buffer.from('{"a":1}\n')) + + expect(messages).toEqual([{ a: 1 }]) + expect(errors).toEqual([]) + expect(socket.destroyed).toBe(false) + }) + + test('destroys a complete frame over the configured byte limit', () => { + const socket = createTestSocket() + const messages: unknown[] = [] + const errors: Error[] = [] + + attachNdjsonFramer( + socket, + msg => messages.push(msg), + text => JSON.parse(text) as unknown, + { + maxFrameBytes: 8, + onFrameError: error => errors.push(error), + }, + ) + + socket.emitData(Buffer.from('{"long":true}\n')) + + expect(messages).toEqual([]) + expect(errors[0]?.message).toContain('NDJSON frame exceeded') + expect(socket.destroyed).toBe(true) + }) + + test('destroys oversized no-newline input before a frame can form', () => { + const socket = createTestSocket() + const messages: unknown[] = [] + const errors: Error[] = [] + + attachNdjsonFramer( + socket, + msg => messages.push(msg), + text => JSON.parse(text) as unknown, + { + maxFrameBytes: 8, + onFrameError: error => errors.push(error), + }, + ) + + socket.emitData(Buffer.from('x'.repeat(9))) + + expect(messages).toEqual([]) + expect(errors[0]?.message).toContain('NDJSON frame exceeded') + expect(socket.destroyed).toBe(true) + }) +}) diff --git a/src/utils/__tests__/teammateMailbox.test.ts b/src/utils/__tests__/teammateMailbox.test.ts new file mode 100644 index 0000000000..577c4331f4 --- /dev/null +++ b/src/utils/__tests__/teammateMailbox.test.ts @@ -0,0 +1,310 @@ +import { afterEach, beforeEach, describe, expect, test } from 'bun:test' +import { mkdir, readFile, rm, writeFile } from 'node:fs/promises' +import { mkdtempSync } from 'node:fs' +import { tmpdir } from 'node:os' +import { dirname, join } from 'node:path' +import { + compactMailboxMessages, + getInboxPath, + markMessageAsReadByIndex, + markMessageAsReadByIdentity, + markMessagesAsRead, + markMessagesAsReadByPredicate, + MAX_MAILBOX_MESSAGE_TEXT_BYTES, + MAX_MAILBOX_MESSAGES, + MAX_READ_MAILBOX_MESSAGES, + MAX_UNREAD_PROTOCOL_MAILBOX_MESSAGES, + readMailbox, + type TeammateMessage, + writeToMailbox, +} from '../teammateMailbox.js' + +let tempHome = '' +let previousConfigDir: string | undefined + +function message( + text: string, + read: boolean, + timestamp = new Date(0).toISOString(), +): TeammateMessage { + return { + from: 'team-lead', + text, + timestamp, + read, + } +} + +async function seedMailbox( + agentName: string, + teamName: string, + messages: TeammateMessage[], +): Promise { + const inboxPath = getInboxPath(agentName, teamName) + await mkdir(dirname(inboxPath), { recursive: true }) + await writeFile(inboxPath, JSON.stringify(messages, null, 2), 'utf-8') +} + +async function readRawMailbox( + agentName: string, + teamName: string, +): Promise { + const content = await readFile(getInboxPath(agentName, teamName), 'utf-8') + return JSON.parse(content) as TeammateMessage[] +} + +beforeEach(() => { + previousConfigDir = process.env.CLAUDE_CONFIG_DIR + tempHome = mkdtempSync(join(tmpdir(), 'teammate-mailbox-')) + process.env.CLAUDE_CONFIG_DIR = tempHome +}) + +afterEach(async () => { + if (previousConfigDir === undefined) { + delete process.env.CLAUDE_CONFIG_DIR + } else { + process.env.CLAUDE_CONFIG_DIR = previousConfigDir + } + await rm(tempHome, { recursive: true, force: true }) +}) + +describe('compactMailboxMessages', () => { + test('prioritizes unread messages and keeps only recent read history', () => { + const compacted = compactMailboxMessages( + [ + message('read-1', true), + message('read-2', true), + message('unread-1', false), + message('read-3', true), + message('unread-2', false), + message('read-4', true), + message('read-5', true), + message('unread-3', false), + ], + { maxMessages: 5, maxReadMessages: 2 }, + ) + + expect(compacted.map(m => m.text)).toEqual([ + 'unread-1', + 'unread-2', + 'read-4', + 'read-5', + 'unread-3', + ]) + }) + + test('retains unread protocol messages separately from regular cap', () => { + const protocol = message( + JSON.stringify({ type: 'permission_response', request_id: 'req-1' }), + false, + ) + const compacted = compactMailboxMessages( + [ + protocol, + ...Array.from({ length: 5 }, (_value, index) => + message(`regular-${index}`, false), + ), + ], + { + maxMessages: 2, + maxReadMessages: 0, + maxUnreadProtocolMessages: 1, + }, + ) + + expect(compacted.map(m => m.text)).toEqual([ + protocol.text, + 'regular-3', + 'regular-4', + ]) + }) + + test('caps unread protocol messages with an independent bound', () => { + const compacted = compactMailboxMessages( + Array.from( + { length: MAX_UNREAD_PROTOCOL_MAILBOX_MESSAGES + 1 }, + (_value, index) => + message( + JSON.stringify({ + type: 'permission_response', + request_id: `req-${index}`, + }), + false, + ), + ), + ) + + expect(compacted).toHaveLength(MAX_UNREAD_PROTOCOL_MAILBOX_MESSAGES) + expect(compacted[0]?.text).toContain('req-1') + }) + + test('keeps retained mailbox bytes under an explicit budget', () => { + const compacted = compactMailboxMessages( + Array.from({ length: 20 }, (_value, index) => + message(`msg-${index}-${'x'.repeat(200)}`, false), + ), + { + maxMessages: 20, + maxReadMessages: 0, + maxRetainedBytes: 1_000, + }, + ) + + expect( + Buffer.byteLength(JSON.stringify(compacted), 'utf8'), + ).toBeLessThanOrEqual(1_000) + expect(compacted.length).toBeLessThan(20) + expect(compacted.at(-1)?.text).toContain('msg-19') + }) +}) + +describe('teammate mailbox retention', () => { + test('writeToMailbox compacts oversized unread inbox files', async () => { + const existing = Array.from( + { length: MAX_MAILBOX_MESSAGES + 20 }, + (_value, index) => message(`old-${index}`, false), + ) + await seedMailbox('worker', 'alpha', existing) + + await writeToMailbox( + 'worker', + { + from: 'team-lead', + text: 'newest', + timestamp: new Date(1).toISOString(), + }, + 'alpha', + ) + + const after = await readMailbox('worker', 'alpha') + expect(after).toHaveLength(MAX_MAILBOX_MESSAGES) + expect(after[0]?.text).toBe('old-21') + expect(after.at(-1)?.text).toBe('newest') + }) + + test('markMessagesAsRead compacts read history after consumption', async () => { + const existing = Array.from( + { length: MAX_MAILBOX_MESSAGES + 20 }, + (_value, index) => message(`msg-${index}`, false), + ) + await seedMailbox('worker', 'alpha', existing) + + await markMessagesAsRead('worker', 'alpha') + + const after = await readRawMailbox('worker', 'alpha') + expect(after).toHaveLength(MAX_READ_MAILBOX_MESSAGES) + expect(after.every(m => m.read)).toBe(true) + expect(after[0]?.text).toBe( + `msg-${MAX_MAILBOX_MESSAGES + 20 - MAX_READ_MAILBOX_MESSAGES}`, + ) + }) + + test('markMessagesAsReadByPredicate leaves structured messages unread', async () => { + await seedMailbox('worker', 'alpha', [ + message('plain', false), + message(JSON.stringify({ type: 'permission_request' }), false), + ]) + + await markMessagesAsReadByPredicate( + 'worker', + m => !m.text.includes('permission_request'), + 'alpha', + ) + + const after = await readRawMailbox('worker', 'alpha') + expect(after.map(m => m.read)).toEqual([true, false]) + }) + + test('markMessageAsReadByIdentity survives compaction shifting indexes', async () => { + const permissionResponse = message( + JSON.stringify({ type: 'permission_response', request_id: 'req-1' }), + false, + ) + await seedMailbox('worker', 'alpha', [ + permissionResponse, + ...Array.from({ length: MAX_MAILBOX_MESSAGES + 20 }, (_value, index) => + message(`regular-${index}`, false), + ), + ]) + + await writeToMailbox( + 'worker', + { + from: 'team-lead', + text: 'newest', + timestamp: new Date(2).toISOString(), + }, + 'alpha', + ) + const marked = await markMessageAsReadByIdentity( + 'worker', + 'alpha', + permissionResponse, + ) + + const after = await readRawMailbox('worker', 'alpha') + expect(marked).toBe(true) + expect(after.some(m => m.text === permissionResponse.text && !m.read)).toBe( + false, + ) + }) + + test('markMessageAsReadByIndex also compacts through the compatibility path', async () => { + const existing = Array.from( + { length: MAX_MAILBOX_MESSAGES + 10 }, + (_value, index) => message(`msg-${index}`, false), + ) + await seedMailbox('worker', 'alpha', existing) + + await markMessageAsReadByIndex('worker', 'alpha', existing.length - 1) + + const after = await readRawMailbox('worker', 'alpha') + expect(after).toHaveLength(MAX_MAILBOX_MESSAGES) + expect(after.some(m => m.text === `msg-${existing.length - 1}`)).toBe(false) + expect(after.at(-1)?.text).toBe(`msg-${existing.length - 2}`) + }) + + test('writeToMailbox rejects oversized message text instead of storing it', async () => { + await expect( + writeToMailbox( + 'worker', + { + from: 'team-lead', + text: 'x'.repeat(MAX_MAILBOX_MESSAGE_TEXT_BYTES + 1), + timestamp: new Date(3).toISOString(), + }, + 'alpha', + ), + ).rejects.toThrow('Mailbox message text exceeds') + + expect(await readRawMailbox('worker', 'alpha')).toEqual([]) + }) + + test('writeToMailbox fails closed when an existing mailbox is corrupt', async () => { + const inboxPath = getInboxPath('worker', 'alpha') + await mkdir(dirname(inboxPath), { recursive: true }) + await writeFile(inboxPath, '{not-json', 'utf-8') + + await expect( + writeToMailbox( + 'worker', + { + from: 'team-lead', + text: 'new', + timestamp: new Date(4).toISOString(), + }, + 'alpha', + ), + ).rejects.toThrow() + + expect(await readFile(inboxPath, 'utf-8')).toBe('{not-json') + }) + + test('readMailbox fails closed on corrupt mailbox content', async () => { + const inboxPath = getInboxPath('worker', 'alpha') + await mkdir(dirname(inboxPath), { recursive: true }) + await writeFile(inboxPath, '{not-json', 'utf-8') + + await expect(readMailbox('worker', 'alpha')).rejects.toThrow() + }) +}) diff --git a/src/utils/__tests__/udsMessaging.test.ts b/src/utils/__tests__/udsMessaging.test.ts new file mode 100644 index 0000000000..52cb57abc6 --- /dev/null +++ b/src/utils/__tests__/udsMessaging.test.ts @@ -0,0 +1,305 @@ +import { afterEach, describe, expect, test } from 'bun:test' +import { chmod, mkdir, rm, stat, symlink, unlink } from 'node:fs/promises' +import { createConnection, createServer } from 'node:net' +import { dirname, join } from 'node:path' +import { tmpdir } from 'node:os' +import { + drainInbox, + MAX_UDS_INBOX_ENTRIES, + MAX_UDS_INBOX_BYTES, + MAX_UDS_FRAME_BYTES, + parseUdsTarget, + sendUdsMessage, + setOnEnqueue, + startUdsMessaging, + stopUdsMessaging, +} from '../udsMessaging.js' + +function socketPath(label: string): string { + const suffix = `${process.pid}-${Date.now()}-${Math.random().toString(16).slice(2)}-${label}` + if (process.platform === 'win32') { + return `\\\\.\\pipe\\claude-code-test-${suffix}` + } + return join(tmpdir(), 'claude-code-test', `${suffix}.sock`) +} + +function sleep(ms: number): Promise { + return new Promise(resolve => setTimeout(resolve, ms)) +} + +async function waitForEnqueues( + expected: number, + sendMessages: () => Promise, +): Promise { + let count = 0 + let resolveDone: (() => void) | undefined + const done = new Promise(resolve => { + resolveDone = resolve + }) + + setOnEnqueue(() => { + count++ + if (count >= expected) resolveDone?.() + }) + + await sendMessages() + await Promise.race([ + done, + sleep(5_000).then(() => { + throw new Error(`Timed out waiting for ${expected} UDS enqueues`) + }), + ]) + setOnEnqueue(null) +} + +afterEach(async () => { + setOnEnqueue(null) + drainInbox() + await stopUdsMessaging() +}) + +async function closeServer(server: ReturnType): Promise { + await new Promise(resolve => { + server.close(() => resolve()) + }) +} + +describe('UDS inbox retention', () => { + test('drainInbox returns each pending socket message once', async () => { + const path = socketPath('drain') + await startUdsMessaging(path, { isExplicit: true }) + expect(process.env.CLAUDE_CODE_MESSAGING_TOKEN).toBeUndefined() + + await waitForEnqueues(2, async () => { + await sendUdsMessage(path, { type: 'text', data: 'one' }) + await sendUdsMessage(path, { type: 'text', data: 'two' }) + }) + + const drained = drainInbox() + expect(drained.map(entry => entry.message.data)).toEqual(['one', 'two']) + expect(drained.every(entry => entry.status === 'processed')).toBe(true) + expect(drainInbox()).toEqual([]) + }) + + test('inbox is capped when messages arrive faster than they are drained', async () => { + const path = socketPath('cap') + await startUdsMessaging(path, { isExplicit: true }) + + await waitForEnqueues(MAX_UDS_INBOX_ENTRIES, async () => { + for (let i = 0; i < MAX_UDS_INBOX_ENTRIES; i++) { + await sendUdsMessage(path, { type: 'text', data: String(i) }) + } + }) + await expect( + sendUdsMessage(path, { type: 'text', data: 'overflow' }), + ).rejects.toThrow('inbox full') + + const drained = drainInbox() + expect(drained).toHaveLength(MAX_UDS_INBOX_ENTRIES) + expect(drained[0]?.message.data).toBe('0') + expect(drained.at(-1)?.message.data).toBe(String(MAX_UDS_INBOX_ENTRIES - 1)) + }) + + test('inbox is capped by retained bytes before entry count', async () => { + const path = socketPath('byte-cap') + await startUdsMessaging(path, { isExplicit: true }) + + const payload = 'x'.repeat(32 * 1024) + let accepted = 0 + for (;;) { + try { + await sendUdsMessage(path, { type: 'text', data: payload }) + accepted++ + if (accepted > MAX_UDS_INBOX_BYTES / payload.length + 20) { + throw new Error('byte cap was not enforced') + } + } catch (error) { + expect(error).toBeInstanceOf(Error) + expect((error as Error).message).toContain('inbox full') + break + } + } + + const drained = drainInbox() + expect(drained.length).toBe(accepted) + expect(drained.length).toBeLessThan(MAX_UDS_INBOX_ENTRIES) + }) + + test('ping replies with pong without enqueueing inbox work', async () => { + const path = socketPath('ping') + await startUdsMessaging(path, { isExplicit: true }) + + await sendUdsMessage(path, { type: 'ping' }) + expect(drainInbox()).toEqual([]) + }) + + test('drained entries never expose the UDS auth token', async () => { + const path = socketPath('strip-token') + await startUdsMessaging(path, { isExplicit: true }) + + await waitForEnqueues(1, async () => { + await sendUdsMessage(path, { + type: 'notification', + meta: { keep: 'visible' }, + }) + }) + + const drained = drainInbox() + expect(drained).toHaveLength(1) + expect(drained[0]?.message.meta).toEqual({ keep: 'visible' }) + expect(drained[0]?.message.meta).not.toHaveProperty('authToken') + }) + + test('rejects unauthenticated socket messages', async () => { + const path = socketPath('auth') + await startUdsMessaging(path, { isExplicit: true }) + + const response = await new Promise((resolve, reject) => { + const conn = createConnection(path, () => { + conn.write(`${JSON.stringify({ type: 'text', data: 'bad' })}\n`) + }) + conn.setTimeout(5_000, () => { + conn.destroy() + reject(new Error('Timed out waiting for auth rejection')) + }) + conn.on('data', chunk => { + const text = chunk.toString('utf-8') + if (text.includes('\n')) { + conn.end() + resolve(text) + } + }) + conn.on('error', reject) + }) + + expect(JSON.parse(response).type).toBe('error') + expect(drainInbox()).toEqual([]) + }) + + test('destroys oversized frames before enqueueing inbox work', async () => { + const path = socketPath('oversized') + await startUdsMessaging(path, { isExplicit: true }) + + await new Promise((resolve, reject) => { + const conn = createConnection(path, () => { + conn.write('x'.repeat(MAX_UDS_FRAME_BYTES + 1)) + }) + conn.setTimeout(5_000, () => { + conn.destroy() + reject(new Error('Timed out waiting for oversized frame close')) + }) + conn.on('close', () => resolve()) + conn.on('error', () => resolve()) + }) + + expect(drainInbox()).toEqual([]) + }) + + test('rejects oversized receiver responses before retaining them', async () => { + const path = socketPath('oversized-response') + if (process.platform !== 'win32') { + await mkdir(dirname(path), { recursive: true }) + } + const receiver = createServer(socket => { + socket.on('data', () => { + socket.write('x'.repeat(MAX_UDS_FRAME_BYTES + 1)) + }) + }) + await new Promise((resolve, reject) => { + receiver.on('error', reject) + receiver.listen(path, () => resolve()) + }) + + try { + await expect( + sendUdsMessage( + path, + { type: 'text', data: 'hello' }, + { authToken: 'test-token' }, + ), + ).rejects.toThrow('UDS response frame exceeded size limit') + } finally { + await closeServer(receiver) + if (process.platform !== 'win32') { + await unlink(path).catch(() => undefined) + } + } + }) + + test('rejects inline auth token UDS targets instead of parsing them', async () => { + const path = socketPath('inline-token') + + const targetWithToken = `${path}#token=secret` + expect(() => parseUdsTarget(targetWithToken)).toThrow('inline auth token') + try { + parseUdsTarget(targetWithToken) + } catch (error) { + expect((error as Error).message).not.toContain('secret') + } + + const { sendToUdsSocket } = await import('../udsClient.js') + await expect(sendToUdsSocket(targetWithToken, 'hello')).rejects.toThrow( + 'inline auth token', + ) + }) + + if (process.platform !== 'win32') { + test('creates the listening socket with owner-only permissions', async () => { + const path = socketPath('socket-mode') + await startUdsMessaging(path, { isExplicit: true }) + + const mode = (await stat(path)).mode & 0o777 + expect(mode).toBe(0o600) + }) + + test('fails closed when the capability directory is not private', async () => { + const previousConfigDir = process.env.CLAUDE_CONFIG_DIR + const tempHome = join( + tmpdir(), + `uds-capability-${process.pid}-${Date.now()}-${Math.random().toString(16).slice(2)}`, + ) + process.env.CLAUDE_CONFIG_DIR = tempHome + const capabilityDir = join(tempHome, 'messaging-capabilities') + await mkdir(capabilityDir, { recursive: true, mode: 0o755 }) + await chmod(capabilityDir, 0o755) + + try { + await expect( + startUdsMessaging(socketPath('broad-capdir'), { isExplicit: true }), + ).rejects.toThrow('permissions are too broad') + } finally { + if (previousConfigDir === undefined) { + delete process.env.CLAUDE_CONFIG_DIR + } else { + process.env.CLAUDE_CONFIG_DIR = previousConfigDir + } + await rm(tempHome, { recursive: true, force: true }) + } + }) + + test('fails closed when the capability directory is a symlink', async () => { + const previousConfigDir = process.env.CLAUDE_CONFIG_DIR + const tempHome = join( + tmpdir(), + `uds-capability-link-${process.pid}-${Date.now()}-${Math.random().toString(16).slice(2)}`, + ) + const target = join(tempHome, 'target') + process.env.CLAUDE_CONFIG_DIR = tempHome + await mkdir(target, { recursive: true, mode: 0o700 }) + await symlink(target, join(tempHome, 'messaging-capabilities'), 'dir') + + try { + await expect( + startUdsMessaging(socketPath('symlink-capdir'), { isExplicit: true }), + ).rejects.toThrow('not a private directory') + } finally { + if (previousConfigDir === undefined) { + delete process.env.CLAUDE_CONFIG_DIR + } else { + process.env.CLAUDE_CONFIG_DIR = previousConfigDir + } + await rm(tempHome, { recursive: true, force: true }) + } + }) + } +}) diff --git a/src/utils/messages/systemInit.ts b/src/utils/messages/systemInit.ts index fcb9e74d11..4585c78172 100644 --- a/src/utils/messages/systemInit.ts +++ b/src/utils/messages/systemInit.ts @@ -87,8 +87,10 @@ export function buildSystemInitMessage(inputs: SystemInitInputs): SDKMessage { // Hidden from public SDK types — ant-only UDS messaging socket path if (feature('UDS_INBOX')) { /* eslint-disable @typescript-eslint/no-require-imports */ + const udsMessaging = + require('../udsMessaging.js') as typeof import('../udsMessaging.js') ;(initMessage as Record).messaging_socket_path = - require('../udsMessaging.js').getUdsMessagingSocketPath() + udsMessaging.getUdsMessagingSocketPath() /* eslint-enable @typescript-eslint/no-require-imports */ } initMessage.fast_mode_state = getFastModeState(inputs.model, inputs.fastMode) diff --git a/src/utils/ndjsonFramer.ts b/src/utils/ndjsonFramer.ts index 968ee5217f..ecaa04dc8c 100644 --- a/src/utils/ndjsonFramer.ts +++ b/src/utils/ndjsonFramer.ts @@ -7,6 +7,11 @@ */ import type { Socket } from 'net' +export type NdjsonFramerOptions = { + maxFrameBytes?: number + onFrameError?: (error: Error) => void +} + /** * Attach an NDJSON framer to a socket. Calls `onMessage` for each * complete JSON line received. Malformed lines are silently skipped. @@ -19,21 +24,54 @@ export function attachNdjsonFramer( socket: Socket, onMessage: (msg: T) => void, parse: (text: string) => T = text => JSON.parse(text) as T, + options: NdjsonFramerOptions = {}, ): void { let buffer = '' + const maxFrameBytes = options.maxFrameBytes ?? Number.POSITIVE_INFINITY + + const rejectOversizedFrame = (bytes: number): void => { + const error = new Error( + `NDJSON frame exceeded ${maxFrameBytes} bytes (${bytes})`, + ) + options.onFrameError?.(error) + socket.destroy(error) + } socket.on('data', (chunk: Buffer) => { + if ( + Number.isFinite(maxFrameBytes) && + !chunk.includes(0x0a) && + Buffer.byteLength(buffer, 'utf8') + chunk.byteLength > maxFrameBytes + ) { + rejectOversizedFrame(Buffer.byteLength(buffer, 'utf8') + chunk.byteLength) + return + } + buffer += chunk.toString() const lines = buffer.split('\n') buffer = lines.pop() ?? '' for (const line of lines) { if (!line.trim()) continue + if ( + Number.isFinite(maxFrameBytes) && + Buffer.byteLength(line, 'utf8') > maxFrameBytes + ) { + rejectOversizedFrame(Buffer.byteLength(line, 'utf8')) + return + } try { onMessage(parse(line)) } catch { // Malformed JSON — skip } } + + if ( + Number.isFinite(maxFrameBytes) && + Buffer.byteLength(buffer, 'utf8') > maxFrameBytes + ) { + rejectOversizedFrame(Buffer.byteLength(buffer, 'utf8')) + } }) } diff --git a/src/utils/swarm/inProcessRunner.ts b/src/utils/swarm/inProcessRunner.ts index 1735500b4d..eaab58ef76 100644 --- a/src/utils/swarm/inProcessRunner.ts +++ b/src/utils/swarm/inProcessRunner.ts @@ -97,7 +97,7 @@ import { getLastPeerDmSummary, isPermissionResponse, isShutdownRequest, - markMessageAsReadByIndex, + markMessageAsReadByIdentity, readMailbox, writeToMailbox, } from '../teammateMailbox.js' @@ -405,10 +405,10 @@ function createInProcessCanUseTool( if (msg && !msg.read) { const parsed = isPermissionResponse(msg.text) if (parsed && parsed.request_id === request.id) { - await markMessageAsReadByIndex( + await markMessageAsReadByIdentity( identity.agentName, identity.teamName, - i, + msg, ) if (parsed.subtype === 'success') { processMailboxPermissionResponse({ @@ -801,10 +801,10 @@ async function waitForNextPromptOrShutdown( logForDebugging( `[inProcessRunner] ${identity.agentName} received shutdown request from ${shutdownParsed?.from} (prioritized over ${skippedUnread} unread messages)`, ) - await markMessageAsReadByIndex( + await markMessageAsReadByIdentity( identity.agentName, identity.teamName, - shutdownIndex, + msg, ) return { type: 'shutdown_request', @@ -839,10 +839,10 @@ async function waitForNextPromptOrShutdown( logForDebugging( `[inProcessRunner] ${identity.agentName} received new message from ${msg.from} (index ${selectedIndex})`, ) - await markMessageAsReadByIndex( + await markMessageAsReadByIdentity( identity.agentName, identity.teamName, - selectedIndex, + msg, ) return { type: 'new_message', @@ -1246,8 +1246,13 @@ export async function runInProcessTeammate( // Track in-progress tool use IDs for animation in transcript view let inProgressToolUseIDs = task.inProgressToolUseIDs if (message.type === 'assistant') { - for (const block of (Array.isArray(message.message!.content) ? message.message!.content : [])) { - if (typeof block !== 'string' && block.type === 'tool_use') { + for (const block of Array.isArray(message.message!.content) + ? message.message!.content + : []) { + if ( + typeof block !== 'string' && + block.type === 'tool_use' + ) { inProgressToolUseIDs = new Set([ ...(inProgressToolUseIDs ?? []), block.id, @@ -1318,7 +1323,10 @@ export async function runInProcessTeammate( setAppState, ) if (currentAutonomyRunId) { - await markAutonomyRunFailed(currentAutonomyRunId, ERROR_MESSAGE_USER_ABORT) + await markAutonomyRunFailed( + currentAutonomyRunId, + ERROR_MESSAGE_USER_ABORT, + ) currentAutonomyRunId = undefined } } else if (currentAutonomyRunId) { diff --git a/src/utils/teammateMailbox.ts b/src/utils/teammateMailbox.ts index eb72fcc218..6c18fd7213 100644 --- a/src/utils/teammateMailbox.ts +++ b/src/utils/teammateMailbox.ts @@ -7,7 +7,8 @@ * Note: Inboxes are keyed by agent name within a team. */ -import { mkdir, readFile, writeFile } from 'fs/promises' +import { randomBytes } from 'crypto' +import { mkdir, readFile, rename, stat, writeFile } from 'fs/promises' import { join } from 'path' import { z } from 'zod/v4' import { TEAMMATE_MESSAGE_TAG } from '../constants/xml.js' @@ -40,6 +41,13 @@ const LOCK_OPTIONS = { }, } +export const MAX_MAILBOX_MESSAGES = 1_000 +export const MAX_READ_MAILBOX_MESSAGES = 200 +export const MAX_UNREAD_PROTOCOL_MAILBOX_MESSAGES = 2_000 +export const MAX_MAILBOX_MESSAGE_TEXT_BYTES = 64 * 1024 +export const MAX_MAILBOX_RETAINED_BYTES = 2 * 1024 * 1024 +export const MAX_MAILBOX_FILE_BYTES = 4 * 1024 * 1024 + export type TeammateMessage = { from: string text: string @@ -49,6 +57,218 @@ export type TeammateMessage = { summary?: string // 5-10 word summary shown as preview in the UI } +function isJsonLikeMessage(text: string): boolean { + const trimmed = text.trimStart() + return trimmed.startsWith('{') || trimmed.startsWith('[') +} + +function shouldRetainUnreadAsProtocolMessage( + message: TeammateMessage, +): boolean { + if (message.read) return false + if (isStructuredProtocolMessage(message.text)) return true + if (!isJsonLikeMessage(message.text)) return false + + try { + const parsed = jsonParse(message.text) + return Boolean( + parsed && + typeof parsed === 'object' && + 'type' in (parsed as Record), + ) + } catch { + return true + } +} + +function sameMailboxMessage(a: TeammateMessage, b: TeammateMessage): boolean { + return a.from === b.from && a.timestamp === b.timestamp && a.text === b.text +} + +function mailboxMessageStorageBytes(message: TeammateMessage): number { + return Buffer.byteLength(jsonStringify(message), 'utf8') +} + +function assertMailboxMessageSize(message: TeammateMessage): void { + const textBytes = Buffer.byteLength(message.text, 'utf8') + if (textBytes > MAX_MAILBOX_MESSAGE_TEXT_BYTES) { + throw new Error( + `Mailbox message text exceeds ${MAX_MAILBOX_MESSAGE_TEXT_BYTES} bytes`, + ) + } +} + +function toMailboxMessage(value: unknown): TeammateMessage { + if (!value || typeof value !== 'object') { + throw new Error('Invalid mailbox message: expected object') + } + const record = value as Record + if ( + typeof record.from !== 'string' || + typeof record.text !== 'string' || + typeof record.timestamp !== 'string' || + typeof record.read !== 'boolean' + ) { + throw new Error('Invalid mailbox message shape') + } + const message: TeammateMessage = { + from: record.from, + text: record.text, + timestamp: record.timestamp, + read: record.read, + ...(typeof record.color === 'string' ? { color: record.color } : {}), + ...(typeof record.summary === 'string' ? { summary: record.summary } : {}), + } + assertMailboxMessageSize(message) + return message +} + +function parseMailboxMessages(content: string): TeammateMessage[] { + const parsed = jsonParse(content) + if (!Array.isArray(parsed)) { + throw new Error('Invalid mailbox file: expected message array') + } + return parsed.map(toMailboxMessage) +} + +async function readMailboxFile(inboxPath: string): Promise { + const info = await stat(inboxPath) + if (info.size > MAX_MAILBOX_FILE_BYTES) { + throw new Error( + `Mailbox file exceeds ${MAX_MAILBOX_FILE_BYTES} bytes: ${inboxPath}`, + ) + } + return readFile(inboxPath, 'utf-8') +} + +async function readMailboxForMutation( + agentName: string, + teamName?: string, +): Promise { + const inboxPath = getInboxPath(agentName, teamName) + return parseMailboxMessages(await readMailboxFile(inboxPath)) +} + +async function writeMailboxAtomic( + inboxPath: string, + content: string, +): Promise { + const bytes = Buffer.byteLength(content, 'utf8') + if (bytes > MAX_MAILBOX_FILE_BYTES) { + throw new Error( + `Compacted mailbox still exceeds ${MAX_MAILBOX_FILE_BYTES} bytes`, + ) + } + const tempPath = `${inboxPath}.${process.pid}.${randomBytes(8).toString('hex')}.tmp` + await writeFile(tempPath, content, 'utf-8') + await rename(tempPath, inboxPath) +} + +export function compactMailboxMessages( + messages: TeammateMessage[], + limits: { + maxMessages?: number + maxReadMessages?: number + maxUnreadProtocolMessages?: number + maxRetainedBytes?: number + } = {}, +): TeammateMessage[] { + const maxMessages = limits.maxMessages ?? MAX_MAILBOX_MESSAGES + const maxReadMessages = limits.maxReadMessages ?? MAX_READ_MAILBOX_MESSAGES + const maxUnreadProtocolMessages = + limits.maxUnreadProtocolMessages ?? MAX_UNREAD_PROTOCOL_MAILBOX_MESSAGES + const maxRetainedBytes = limits.maxRetainedBytes ?? MAX_MAILBOX_RETAINED_BYTES + + if ( + maxRetainedBytes <= 0 || + (maxMessages <= 0 && maxUnreadProtocolMessages <= 0) + ) { + return [] + } + + const keepIndexes = new Set() + let retainedBytes = 0 + let keptUnreadProtocolMessages = 0 + const tryKeep = (index: number): boolean => { + if (keepIndexes.has(index)) return true + const message = messages[index] + if (!message) return false + const bytes = mailboxMessageStorageBytes(message) + if (bytes > maxRetainedBytes || retainedBytes + bytes > maxRetainedBytes) { + return false + } + keepIndexes.add(index) + retainedBytes += bytes + return true + } + + for (let i = messages.length - 1; i >= 0; i--) { + const message = messages[i] + if (!message || !shouldRetainUnreadAsProtocolMessage(message)) continue + if (keptUnreadProtocolMessages >= maxUnreadProtocolMessages) continue + if (tryKeep(i)) keptUnreadProtocolMessages++ + } + + let keptNonProtocolMessages = 0 + for (let i = messages.length - 1; i >= 0; i--) { + if (keptNonProtocolMessages >= maxMessages) break + const message = messages[i] + if ( + message && + !message.read && + !shouldRetainUnreadAsProtocolMessage(message) + ) { + if (tryKeep(i)) keptNonProtocolMessages++ + } + } + + let keptReadMessages = 0 + for (let i = messages.length - 1; i >= 0; i--) { + if (keptNonProtocolMessages >= maxMessages) break + if (keptReadMessages >= maxReadMessages) break + const message = messages[i] + if (message?.read) { + if (tryKeep(i)) { + keptReadMessages++ + keptNonProtocolMessages++ + } + } + } + + return messages.filter((_message, index) => keepIndexes.has(index)) +} + +function logUnreadMailboxEvictions( + original: TeammateMessage[], + compacted: TeammateMessage[], + context: string, +): void { + const kept = new Set(compacted) + const unreadEvicted = original.filter(message => { + return !message.read && !kept.has(message) + }) + if (unreadEvicted.length === 0) return + + const protocolEvicted = count(unreadEvicted, message => + shouldRetainUnreadAsProtocolMessage(message), + ) + logError( + new Error( + `[TeammateMailbox] Compacted ${unreadEvicted.length} unread message(s) in ${context}; protocol_or_unknown=${protocolEvicted}`, + ), + ) +} + +async function writeCompactedMailbox( + inboxPath: string, + messages: TeammateMessage[], + context: string, +): Promise { + const compacted = compactMailboxMessages(messages) + logUnreadMailboxEvictions(messages, compacted, context) + await writeMailboxAtomic(inboxPath, jsonStringify(compacted, null, 2)) +} + /** * Get the path to a teammate's inbox file * Structure: ~/.claude/teams/{team_name}/inboxes/{agent_name}.json @@ -89,8 +309,7 @@ export async function readMailbox( logForDebugging(`[TeammateMailbox] readMailbox: path=${inboxPath}`) try { - const content = await readFile(inboxPath, 'utf-8') - const messages = jsonParse(content) as TeammateMessage[] + const messages = parseMailboxMessages(await readMailboxFile(inboxPath)) logForDebugging( `[TeammateMailbox] readMailbox: read ${messages.length} message(s)`, ) @@ -103,7 +322,7 @@ export async function readMailbox( } logForDebugging(`Failed to read inbox for ${agentName}: ${error}`) logError(error) - return [] + throw error } } @@ -156,7 +375,7 @@ export async function writeToMailbox( `[TeammateMailbox] writeToMailbox: failed to create inbox file: ${error}`, ) logError(error) - return + throw error } } @@ -168,22 +387,23 @@ export async function writeToMailbox( }) // Re-read messages after acquiring lock to get the latest state - const messages = await readMailbox(recipientName, teamName) + const messages = await readMailboxForMutation(recipientName, teamName) - const newMessage: TeammateMessage = { + const newMessage = toMailboxMessage({ ...message, read: false, - } + }) messages.push(newMessage) - await writeFile(inboxPath, jsonStringify(messages, null, 2), 'utf-8') + await writeCompactedMailbox(inboxPath, messages, 'writeToMailbox') logForDebugging( `[TeammateMailbox] Wrote message to ${recipientName}'s inbox from ${message.from}`, ) } catch (error) { logForDebugging(`Failed to write to inbox for ${recipientName}: ${error}`) logError(error) + throw error } finally { if (release) { await release() @@ -222,7 +442,7 @@ export async function markMessageAsReadByIndex( logForDebugging(`[TeammateMailbox] markMessageAsReadByIndex: lock acquired`) // Re-read messages after acquiring lock to get the latest state - const messages = await readMailbox(agentName, teamName) + const messages = await readMailboxForMutation(agentName, teamName) logForDebugging( `[TeammateMailbox] markMessageAsReadByIndex: read ${messages.length} messages after lock`, ) @@ -244,7 +464,7 @@ export async function markMessageAsReadByIndex( messages[messageIndex] = { ...message, read: true } - await writeFile(inboxPath, jsonStringify(messages, null, 2), 'utf-8') + await writeCompactedMailbox(inboxPath, messages, 'markMessageAsReadByIndex') logForDebugging( `[TeammateMailbox] markMessageAsReadByIndex: marked message at index ${messageIndex} as read`, ) @@ -270,6 +490,46 @@ export async function markMessageAsReadByIndex( } } +export async function markMessageAsReadByIdentity( + agentName: string, + teamName: string | undefined, + expectedMessage: TeammateMessage, +): Promise { + const inboxPath = getInboxPath(agentName, teamName) + const lockFilePath = `${inboxPath}.lock` + + let release: (() => Promise) | undefined + try { + release = await lockfile.lock(inboxPath, { + lockfilePath: lockFilePath, + ...LOCK_OPTIONS, + }) + + const messages = await readMailboxForMutation(agentName, teamName) + const messageIndex = messages.findIndex(message => { + return !message.read && sameMailboxMessage(message, expectedMessage) + }) + if (messageIndex < 0) return false + + messages[messageIndex] = { ...messages[messageIndex]!, read: true } + await writeCompactedMailbox( + inboxPath, + messages, + 'markMessageAsReadByIdentity', + ) + return true + } catch (error) { + const code = getErrnoCode(error) + if (code === 'ENOENT') return false + logError(error) + return false + } finally { + if (release) { + await release() + } + } +} + /** * Mark all messages in a teammate's inbox as read * Uses file locking to prevent race conditions @@ -297,7 +557,7 @@ export async function markMessagesAsRead( logForDebugging(`[TeammateMailbox] markMessagesAsRead: lock acquired`) // Re-read messages after acquiring lock to get the latest state - const messages = await readMailbox(agentName, teamName) + const messages = await readMailboxForMutation(agentName, teamName) logForDebugging( `[TeammateMailbox] markMessagesAsRead: read ${messages.length} messages after lock`, ) @@ -317,7 +577,7 @@ export async function markMessagesAsRead( // messages comes from jsonParse — fresh, unshared objects safe to mutate for (const m of messages) m.read = true - await writeFile(inboxPath, jsonStringify(messages, null, 2), 'utf-8') + await writeCompactedMailbox(inboxPath, messages, 'markMessagesAsRead') logForDebugging( `[TeammateMailbox] markMessagesAsRead: WROTE ${unreadCount} message(s) as read to ${inboxPath}`, ) @@ -1114,7 +1374,7 @@ export async function markMessagesAsReadByPredicate( ...LOCK_OPTIONS, }) - const messages = await readMailbox(agentName, teamName) + const messages = await readMailboxForMutation(agentName, teamName) if (messages.length === 0) { return } @@ -1123,7 +1383,11 @@ export async function markMessagesAsReadByPredicate( !m.read && predicate(m) ? { ...m, read: true } : m, ) - await writeFile(inboxPath, jsonStringify(updatedMessages, null, 2), 'utf-8') + await writeCompactedMailbox( + inboxPath, + updatedMessages, + 'markMessagesAsReadByPredicate', + ) } catch (error) { const code = getErrnoCode(error) if (code === 'ENOENT') { @@ -1161,7 +1425,12 @@ export function getLastPeerDmSummary(messages: Message[]): string | undefined { if (!Array.isArray(content)) continue for (const block of content) { if (typeof block === 'string') continue - const b = block as unknown as { type: string; name?: string; input?: Record; [key: string]: unknown } + const b = block as unknown as { + type: string + name?: string + input?: Record + [key: string]: unknown + } if ( b.type === 'tool_use' && b.name === SEND_MESSAGE_TOOL_NAME && @@ -1177,7 +1446,7 @@ export function getLastPeerDmSummary(messages: Message[]): string | undefined { const to = b.input.to as string const summary = 'summary' in b.input && typeof b.input.summary === 'string' - ? b.input.summary as string + ? (b.input.summary as string) : (b.input.message as string).slice(0, 80) return `[to ${to}] ${summary}` } diff --git a/src/utils/udsClient.ts b/src/utils/udsClient.ts index 781f3ddd15..f08bee696b 100644 --- a/src/utils/udsClient.ts +++ b/src/utils/udsClient.ts @@ -16,7 +16,7 @@ import { errorMessage, isFsInaccessible } from './errors.js' import { isProcessRunning } from './genericProcessUtils.js' import { jsonParse, jsonStringify } from './slowOperations.js' import type { SessionKind } from './concurrentSessions.js' -import type { UdsMessage } from './udsMessaging.js' +import { MAX_UDS_FRAME_BYTES, type UdsMessage } from './udsMessaging.js' // --------------------------------------------------------------------------- // Types @@ -43,6 +43,12 @@ function getSessionsDir(): string { return join(getClaudeConfigHomeDir(), 'sessions') } +function getChunkBytes(chunk: string | Buffer): number { + return typeof chunk === 'string' + ? Buffer.byteLength(chunk, 'utf8') + : chunk.byteLength +} + // --------------------------------------------------------------------------- // Discovery // --------------------------------------------------------------------------- @@ -104,9 +110,14 @@ export async function listAllLiveSessions(): Promise { */ export async function listPeers(): Promise { const all = await listAllLiveSessions() - return all.filter( - s => s.pid !== process.pid && s.messagingSocketPath != null, - ) + return all.filter(s => s.pid !== process.pid && s.messagingSocketPath != null) +} + +async function findAuthTokenForSocketPath( + socketPath: string, +): Promise { + const { readUdsCapabilityToken } = await import('./udsMessaging.js') + return readUdsCapabilityToken(socketPath) } // --------------------------------------------------------------------------- @@ -117,10 +128,21 @@ export async function listPeers(): Promise { * Probe a UDS socket to check if a server is listening (ping/pong). * Returns true if the peer responds within the timeout. */ -export async function isPeerAlive(socketPath: string, timeoutMs = 3000): Promise { - return new Promise((resolve) => { +export async function isPeerAlive( + socketPath: string, + timeoutMs = 3000, + authToken?: string, +): Promise { + const token = authToken ?? (await findAuthTokenForSocketPath(socketPath)) + if (!token) return false + + return new Promise(resolve => { const conn = createConnection(socketPath, () => { - const ping: UdsMessage = { type: 'ping', ts: new Date().toISOString() } + const ping: UdsMessage = { + type: 'ping', + ts: new Date().toISOString(), + meta: { authToken: token }, + } conn.write(jsonStringify(ping) + '\n') }) @@ -135,7 +157,19 @@ export async function isPeerAlive(socketPath: string, timeoutMs = 3000): Promise }, timeoutMs) let buffer = '' - conn.on('data', (chunk) => { + conn.on('data', chunk => { + if ( + Buffer.byteLength(buffer, 'utf8') + getChunkBytes(chunk) > + MAX_UDS_FRAME_BYTES + ) { + if (!resolved) { + resolved = true + clearTimeout(timer) + conn.destroy() + resolve(false) + } + return + } buffer += chunk.toString() if (buffer.includes('"pong"')) { if (!resolved) { @@ -165,6 +199,13 @@ export async function sendToUdsSocket( targetSocketPath: string, message: string | Record, ): Promise { + const { parseUdsTarget } = await import('./udsMessaging.js') + const target = parseUdsTarget(targetSocketPath) + const authToken = await findAuthTokenForSocketPath(target.socketPath) + if (!authToken) { + throw new Error(`No auth token found for peer at ${target.socketPath}`) + } + const data = typeof message === 'string' ? message : jsonStringify(message) const udsMsg: UdsMessage = { type: 'text', @@ -177,18 +218,59 @@ export async function sendToUdsSocket( udsMsg.from = getUdsMessagingSocketPath() return new Promise((resolve, reject) => { - const conn = createConnection(targetSocketPath, () => { - conn.write(jsonStringify(udsMsg) + '\n', (err) => { - conn.end() - if (err) reject(err) - else resolve() + let buffer = '' + let settled = false + const finish = (error?: Error): void => { + if (settled) return + settled = true + conn.end() + if (error) reject(error) + else resolve() + } + const conn = createConnection(target.socketPath, () => { + udsMsg.meta = { ...udsMsg.meta, authToken } + conn.write(jsonStringify(udsMsg) + '\n', err => { + if (err) finish(err) }) }) - conn.on('error', (err) => { - reject(new Error(`Failed to connect to peer at ${targetSocketPath}: ${errorMessage(err)}`)) + conn.on('data', chunk => { + if ( + Buffer.byteLength(buffer, 'utf8') + getChunkBytes(chunk) > + MAX_UDS_FRAME_BYTES + ) { + finish(new Error('UDS response frame exceeded size limit')) + return + } + buffer += chunk.toString() + const lines = buffer.split('\n') + buffer = lines.pop() ?? '' + for (const line of lines) { + if (!line.trim()) continue + let response: UdsMessage + try { + response = jsonParse(line) as UdsMessage + } catch { + continue + } + if (response.type === 'response') { + finish() + return + } + if (response.type === 'error') { + finish(new Error(response.data ?? 'UDS receiver rejected message')) + return + } + } + }) + conn.on('error', err => { + finish( + new Error( + `Failed to connect to peer at ${target.socketPath}: ${errorMessage(err)}`, + ), + ) }) conn.setTimeout(5000, () => { - conn.destroy(new Error('Connection timed out')) + finish(new Error('Connection timed out')) }) }) } diff --git a/src/utils/udsMessaging.ts b/src/utils/udsMessaging.ts index 1c95ab63c2..7efa7fbf65 100644 --- a/src/utils/udsMessaging.ts +++ b/src/utils/udsMessaging.ts @@ -8,14 +8,25 @@ * but can be overridden via --messaging-socket-path. */ +import { createHash, randomBytes } from 'crypto' import { createServer, type Server, type Socket } from 'net' -import { mkdir, unlink } from 'fs/promises' +import { + chmod, + lstat, + mkdir, + open, + readFile, + rename, + unlink, +} from 'fs/promises' import { dirname, join } from 'path' import { tmpdir } from 'os' import { registerCleanup } from './cleanupRegistry.js' import { logForDebugging } from './debug.js' import { errorMessage } from './errors.js' +import { getClaudeConfigHomeDir } from './envUtils.js' import { attachNdjsonFramer } from './ndjsonFramer.js' +import { logError } from './log.js' import { jsonParse, jsonStringify } from './slowOperations.js' // --------------------------------------------------------------------------- @@ -27,6 +38,7 @@ export type UdsMessageType = | 'notification' | 'query' | 'response' + | 'error' | 'ping' | 'pong' @@ -60,6 +72,15 @@ let onEnqueueCb: (() => void) | null = null const clients = new Set() const inbox: UdsInboxEntry[] = [] let nextId = 1 +let defaultSocketPath: string | null = null +let authToken: string | null = null +let capabilityFilePath: string | null = null +let inboxBytes = 0 + +export const MAX_UDS_INBOX_ENTRIES = 1_000 +export const MAX_UDS_FRAME_BYTES = 64 * 1024 +export const MAX_UDS_INBOX_BYTES = 2 * 1024 * 1024 +export const MAX_UDS_CLIENTS = 128 // --------------------------------------------------------------------------- // Public API — socket path helpers @@ -74,10 +95,19 @@ let nextId = 1 * transparently, but we use the pipe format on Windows for Node.js compat. */ export function getDefaultUdsSocketPath(): string { + if (defaultSocketPath) return defaultSocketPath + const nonce = randomBytes(16).toString('hex') if (process.platform === 'win32') { - return `\\\\.\\pipe\\claude-code-${process.pid}` + defaultSocketPath = `\\\\.\\pipe\\claude-code-${process.pid}-${nonce}` + return defaultSocketPath } - return join(tmpdir(), 'claude-code-socks', `${process.pid}.sock`) + defaultSocketPath = join( + tmpdir(), + 'claude-code-socks', + `${process.pid}-${nonce}`, + 'messaging.sock', + ) + return defaultSocketPath } /** @@ -88,6 +118,142 @@ export function getUdsMessagingSocketPath(): string | undefined { return socketPath ?? undefined } +export function formatUdsAddress(socket: string): string { + return `uds:${socket}` +} + +export function parseUdsTarget(target: string): { + socketPath: string +} { + if (target.includes('#token=')) { + throw new Error( + 'UDS target must not include an inline auth token; use the ListPeers address', + ) + } + return { socketPath: target } +} + +function getCapabilityDir(): string { + return join(getClaudeConfigHomeDir(), 'messaging-capabilities') +} + +function getCapabilityPath(socket: string): string { + const digest = createHash('sha256').update(socket).digest('hex') + return join(getCapabilityDir(), `${digest}.json`) +} + +function isNotFound(error: unknown): boolean { + return ( + typeof error === 'object' && + error !== null && + (error as NodeJS.ErrnoException).code === 'ENOENT' + ) +} + +async function assertPrivateCapabilityDir(dir: string): Promise { + let stat: Awaited> + try { + stat = await lstat(dir) + } catch (error) { + if (!isNotFound(error)) throw error + await mkdir(dir, { recursive: true, mode: 0o700 }) + stat = await lstat(dir) + } + + if (!stat.isDirectory() || stat.isSymbolicLink()) { + throw new Error( + `[udsMessaging] capability directory is not a private directory: ${dir}`, + ) + } + if (process.platform !== 'win32') { + const broadMode = stat.mode & 0o077 + if (broadMode !== 0) { + throw new Error( + `[udsMessaging] capability directory permissions are too broad: ${dir}`, + ) + } + if (typeof process.getuid === 'function' && stat.uid !== process.getuid()) { + throw new Error( + `[udsMessaging] capability directory owner does not match current user: ${dir}`, + ) + } + } + + await chmod(dir, 0o700) +} + +async function writePrivateFileExclusive( + path: string, + content: string, +): Promise { + const handle = await open(path, 'wx', 0o600) + try { + await handle.writeFile(content, 'utf-8') + } finally { + await handle.close() + } + await chmod(path, 0o600) +} + +async function ensureSocketParent(path: string): Promise { + const dir = dirname(path) + try { + const stat = await lstat(dir) + if (!stat.isDirectory() || stat.isSymbolicLink()) { + throw new Error( + `[udsMessaging] socket parent is not a directory: ${dir}`, + ) + } + return + } catch (error) { + if (!isNotFound(error)) throw error + } + + await mkdir(dir, { recursive: true, mode: 0o700 }) + await chmod(dir, 0o700) +} + +async function writeCapabilityFile( + socket: string, + token: string, +): Promise { + const dir = getCapabilityDir() + await assertPrivateCapabilityDir(dir) + const target = getCapabilityPath(socket) + const temp = `${target}.${process.pid}.${randomBytes(8).toString('hex')}.tmp` + try { + await writePrivateFileExclusive( + temp, + jsonStringify({ socketPath: socket, authToken: token }), + ) + await rename(temp, target) + } catch (error) { + try { + await unlink(temp) + } catch { + // Temp file may not exist if exclusive creation failed. + } + throw error + } + capabilityFilePath = target +} + +export async function readUdsCapabilityToken( + socket: string, +): Promise { + try { + const parsed = jsonParse( + await readFile(getCapabilityPath(socket), 'utf-8'), + ) as Record + if (parsed.socketPath === socket && typeof parsed.authToken === 'string') { + return parsed.authToken + } + } catch { + // Missing or unreadable capability file means the peer is not addressable. + } + return undefined +} + // --------------------------------------------------------------------------- // Inbox // --------------------------------------------------------------------------- @@ -101,16 +267,79 @@ export function setOnEnqueue(cb: (() => void) | null): void { } /** - * Drain all pending inbox messages, marking them processed. + * Drain all pending inbox messages and release retained history. */ export function drainInbox(): UdsInboxEntry[] { - const pending = inbox.filter(e => e.status === 'pending') + const pending = inbox.splice(0, inbox.length) + inboxBytes = 0 for (const entry of pending) { entry.status = 'processed' } return pending } +function getMessageBytes(message: UdsMessage): number { + return Buffer.byteLength(jsonStringify(message), 'utf8') +} + +function enqueueInboxEntry(entry: UdsInboxEntry): boolean { + const entryBytes = getMessageBytes(entry.message) + if ( + entryBytes > MAX_UDS_FRAME_BYTES || + inbox.length >= MAX_UDS_INBOX_ENTRIES || + inboxBytes + entryBytes > MAX_UDS_INBOX_BYTES + ) { + logError( + new Error( + `[udsMessaging] inbox full (${inbox.length}/${MAX_UDS_INBOX_ENTRIES}, ${inboxBytes}/${MAX_UDS_INBOX_BYTES} bytes); dropping message type=${entry.message.type}`, + ), + ) + return false + } + inbox.push(entry) + inboxBytes += entryBytes + return true +} + +function ensureAuthToken(): string { + if (!authToken) { + authToken = randomBytes(32).toString('hex') + } + return authToken +} + +function getMessageAuthToken(message: UdsMessage): string | undefined { + const token = message.meta?.authToken + return typeof token === 'string' ? token : undefined +} + +function isAuthorizedMessage(message: UdsMessage): boolean { + return getMessageAuthToken(message) === authToken +} + +function writeSocketMessage(socket: Socket, message: UdsMessage): void { + if (socket.destroyed) return + socket.write(jsonStringify(message) + '\n') +} + +function stripAuthToken(message: UdsMessage): UdsMessage { + const { authToken: _authToken, ...metaWithoutAuth } = message.meta ?? {} + return { + ...message, + meta: Object.keys(metaWithoutAuth).length > 0 ? metaWithoutAuth : undefined, + } +} + +function withRequestAuthToken(message: UdsMessage, token: string): UdsMessage { + return { + ...message, + meta: { + ...message.meta, + authToken: token, + }, + } +} + // --------------------------------------------------------------------------- // Server // --------------------------------------------------------------------------- @@ -132,7 +361,7 @@ export async function startUdsMessaging( // Ensure parent directory exists (skip on Windows — pipe paths aren't files) if (process.platform !== 'win32') { - await mkdir(dirname(path), { recursive: true }) + await ensureSocketParent(path) } // Clean up stale socket file (skip on Windows — pipe paths aren't files) @@ -144,69 +373,134 @@ export async function startUdsMessaging( } } - socketPath = path - - await new Promise((resolve, reject) => { - const srv = createServer(socket => { - clients.add(socket) - logForDebugging( - `[udsMessaging] client connected (total: ${clients.size})`, - ) - - attachNdjsonFramer( - socket, - msg => { - // Handle ping with automatic pong - if (msg.type === 'ping') { - const pong: UdsMessage = { - type: 'pong', - from: socketPath ?? undefined, - ts: new Date().toISOString(), - } - if (!socket.destroyed) { - socket.write(jsonStringify(pong) + '\n') - } - return - } + const token = ensureAuthToken() + try { + await writeCapabilityFile(path, token) + socketPath = path - // Enqueue into inbox - const entry: UdsInboxEntry = { - id: `uds-${nextId++}`, - message: msg, - receivedAt: Date.now(), - status: 'pending', - } - inbox.push(entry) + await new Promise((resolve, reject) => { + const srv = createServer(socket => { + if (clients.size >= MAX_UDS_CLIENTS) { logForDebugging( - `[udsMessaging] enqueued message type=${msg.type} from=${msg.from ?? 'unknown'}`, + `[udsMessaging] rejected client: ${clients.size}/${MAX_UDS_CLIENTS} clients already connected`, ) - onEnqueueCb?.() - }, - text => jsonParse(text) as UdsMessage, - ) + socket.destroy() + return + } + clients.add(socket) + logForDebugging( + `[udsMessaging] client connected (total: ${clients.size})`, + ) + + attachNdjsonFramer( + socket, + msg => { + if (!isAuthorizedMessage(msg)) { + logForDebugging( + `[udsMessaging] rejected unauthenticated message type=${msg.type}`, + ) + if (!socket.destroyed) { + socket.write( + jsonStringify({ + type: 'error', + data: 'unauthorized', + ts: new Date().toISOString(), + } satisfies UdsMessage) + '\n', + ) + } + return + } - socket.on('close', () => { - clients.delete(socket) - }) + // Handle ping with automatic pong + if (msg.type === 'ping') { + writeSocketMessage(socket, { + type: 'pong', + from: socketPath ?? undefined, + ts: new Date().toISOString(), + }) + return + } - socket.on('error', err => { - clients.delete(socket) - logForDebugging(`[udsMessaging] client error: ${errorMessage(err)}`) + // Enqueue into inbox + const sanitizedMessage = stripAuthToken(msg) + const entry: UdsInboxEntry = { + id: `uds-${nextId++}`, + message: sanitizedMessage, + receivedAt: Date.now(), + status: 'pending', + } + if (!enqueueInboxEntry(entry)) { + writeSocketMessage(socket, { + type: 'error', + data: 'inbox full', + ts: new Date().toISOString(), + }) + return + } + logForDebugging( + `[udsMessaging] enqueued message type=${msg.type} from=${msg.from ?? 'unknown'}`, + ) + writeSocketMessage(socket, { + type: 'response', + data: 'ok', + ts: new Date().toISOString(), + meta: { id: entry.id }, + }) + onEnqueueCb?.() + }, + text => jsonParse(text) as UdsMessage, + { + maxFrameBytes: MAX_UDS_FRAME_BYTES, + onFrameError: error => { + logForDebugging(`[udsMessaging] ${error.message}`) + }, + }, + ) + + socket.on('close', () => { + clients.delete(socket) + }) + + socket.on('error', err => { + clients.delete(socket) + logForDebugging(`[udsMessaging] client error: ${errorMessage(err)}`) + }) }) - }) - srv.on('error', reject) + srv.on('error', reject) - srv.listen(path, () => { - server = srv - // Export so child processes can discover the socket - process.env.CLAUDE_CODE_MESSAGING_SOCKET = path - logForDebugging( - `[udsMessaging] server listening on ${path}${opts?.isExplicit ? ' (explicit)' : ''}`, - ) - resolve() + srv.listen(path, () => { + void (async () => { + try { + if (process.platform !== 'win32') { + await chmod(path, 0o600) + } + server = srv + // Export so child processes can discover the socket + process.env.CLAUDE_CODE_MESSAGING_SOCKET = path + logForDebugging( + `[udsMessaging] server listening on ${path}${opts?.isExplicit ? ' (explicit)' : ''}`, + ) + resolve() + } catch (error) { + srv.close(() => reject(error)) + } + })() + }) }) - }) + } catch (error) { + if (capabilityFilePath) { + try { + await unlink(capabilityFilePath) + } catch { + // Already gone. + } + capabilityFilePath = null + } + socketPath = null + authToken = null + throw error + } // Register cleanup so the socket file is removed on exit registerCleanup(async () => { @@ -230,6 +524,9 @@ export async function stopUdsMessaging(): Promise { server!.close(() => resolve()) }) server = null + inbox.length = 0 + inboxBytes = 0 + onEnqueueCb = null // Remove socket file (skip on Windows — pipe paths aren't files) if (socketPath) { @@ -245,9 +542,32 @@ export async function stopUdsMessaging(): Promise { `[udsMessaging] server stopped, socket removed: ${socketPath}`, ) socketPath = null + authToken = null + } + if (capabilityFilePath) { + try { + await unlink(capabilityFilePath) + } catch { + // Already gone + } + capabilityFilePath = null + } +} + +function parseResponseLine(line: string): UdsMessage | null { + try { + return jsonParse(line) as UdsMessage + } catch { + return null } } +function getChunkBytes(chunk: string | Buffer): number { + return typeof chunk === 'string' + ? Buffer.byteLength(chunk, 'utf8') + : chunk.byteLength +} + /** * Send a UDS message to a specific socket path (outbound — used when this * session wants to push a message to a peer's server). @@ -255,23 +575,66 @@ export async function stopUdsMessaging(): Promise { export async function sendUdsMessage( targetSocketPath: string, message: UdsMessage, + opts: { authToken?: string } = {}, ): Promise { const { createConnection } = await import('net') - message.from = message.from ?? socketPath ?? undefined - message.ts = message.ts ?? new Date().toISOString() + const token = opts.authToken ?? authToken + if (!token) { + throw new Error('Cannot send UDS message without auth token') + } + const outbound = withRequestAuthToken( + { + ...message, + from: message.from ?? socketPath ?? undefined, + ts: message.ts ?? new Date().toISOString(), + }, + token, + ) return new Promise((resolve, reject) => { + let buffer = '' + let settled = false + const finish = (error?: Error): void => { + if (settled) return + settled = true + conn.end() + if (error) reject(error) + else resolve() + } const conn = createConnection(targetSocketPath, () => { - conn.write(jsonStringify(message) + '\n', err => { - conn.end() - if (err) reject(err) - else resolve() + conn.write(jsonStringify(outbound) + '\n', err => { + if (err) finish(err) }) }) - conn.on('error', reject) + conn.on('data', chunk => { + if ( + Buffer.byteLength(buffer, 'utf8') + getChunkBytes(chunk) > + MAX_UDS_FRAME_BYTES + ) { + finish(new Error('UDS response frame exceeded size limit')) + return + } + buffer += chunk.toString() + const lines = buffer.split('\n') + buffer = lines.pop() ?? '' + for (const line of lines) { + if (!line.trim()) continue + const response = parseResponseLine(line) + if (!response) continue + if (response.type === 'response' || response.type === 'pong') { + finish() + return + } + if (response.type === 'error') { + finish(new Error(response.data ?? 'UDS receiver rejected message')) + return + } + } + }) + conn.on('error', err => finish(err)) // Timeout so we don't hang on unreachable sockets conn.setTimeout(5000, () => { - conn.destroy(new Error('Connection timed out')) + finish(new Error('Connection timed out')) }) }) } From ee0d788e5823a8cf43341e582de19b794dcdc661 Mon Sep 17 00:00:00 2001 From: unraid Date: Mon, 27 Apr 2026 10:32:18 +0800 Subject: [PATCH 2/3] fix: harden bounded agent communication review fixes CodeRabbit and Codecov surfaced real gaps in UDS framing, peer discovery, mailbox retention, and summary context coverage. This tightens those paths without suppressing review or coverage signals. Constraint: PR #369 must address CodeRabbit and Codecov findings without warning suppression or fake fallbacks Rejected: Suppress Codecov or CodeRabbit warnings | leaves real receive-path and test-isolation gaps Rejected: Add unreachable feature-gated tests | bun:bundle keeps those branches compile-time gated in local tests Confidence: high Scope-risk: moderate Directive: Keep UDS auth-token rejection outside feature flags; do not reintroduce inline token fallbacks Tested: bun test --coverage --coverage-reporter lcov --coverage-dir coverage; bun run test:all; bun run lint; bun run build; bun run build:vite; bun audit; git diff --cached --check Not-tested: Remote Codecov/CodeRabbit refreshed reports until pushed --- .../src/tools/ListPeersTool/ListPeersTool.ts | 16 ++- .../tools/SendMessageTool/SendMessageTool.ts | 40 ++++-- .../udsRecipientSanitization.test.ts | 82 ++++++++++- src/bridge/peerSessions.ts | 32 +++++ src/cli/print.ts | 2 +- .../__tests__/agentSummary.test.ts | 133 ++++++++++++++++++ .../__tests__/summaryContext.test.ts | 17 +++ src/services/AgentSummary/summaryContext.ts | 9 +- src/utils/__tests__/teammateMailbox.test.ts | 70 +++++++++ src/utils/__tests__/udsMessaging.test.ts | 100 ++++++++++++- src/utils/ndjsonFramer.ts | 52 ++++--- src/utils/teammateMailbox.ts | 13 +- src/utils/udsClient.ts | 58 ++------ src/utils/udsMessaging.ts | 93 ++++++------ src/utils/udsResponseReader.ts | 81 +++++++++++ 15 files changed, 651 insertions(+), 147 deletions(-) create mode 100644 src/services/AgentSummary/__tests__/agentSummary.test.ts create mode 100644 src/utils/udsResponseReader.ts diff --git a/packages/builtin-tools/src/tools/ListPeersTool/ListPeersTool.ts b/packages/builtin-tools/src/tools/ListPeersTool/ListPeersTool.ts index 48219edc72..f6c52ea80e 100644 --- a/packages/builtin-tools/src/tools/ListPeersTool/ListPeersTool.ts +++ b/packages/builtin-tools/src/tools/ListPeersTool/ListPeersTool.ts @@ -84,19 +84,27 @@ Use this tool to discover messaging targets before sending cross-session message // UDS socket directory. The implementation scans for live sockets // and optionally includes Remote Control bridge peers. const peers: PeerInfo[] = [] + const seen = new Set() + const addPeer = (peer: PeerInfo): void => { + if (seen.has(peer.address)) return + seen.add(peer.address) + peers.push(peer) + } /* eslint-disable @typescript-eslint/no-require-imports */ const udsMessaging = require('src/utils/udsMessaging.js') as typeof import('src/utils/udsMessaging.js') const udsClient = require('src/utils/udsClient.js') as typeof import('src/utils/udsClient.js') + const bridgePeers = + require('src/bridge/peerSessions.js') as typeof import('src/bridge/peerSessions.js') /* eslint-enable @typescript-eslint/no-require-imports */ const messagingSocketPath = udsMessaging.getUdsMessagingSocketPath() if (messagingSocketPath) { // Self entry for reference if (_input.include_self) { - peers.push({ + addPeer({ address: udsMessaging.formatUdsAddress(messagingSocketPath), name: 'self', pid: process.pid, @@ -106,7 +114,7 @@ Use this tool to discover messaging targets before sending cross-session message for (const peer of await udsClient.listPeers()) { if (!peer.messagingSocketPath) continue - peers.push({ + addPeer({ address: udsMessaging.formatUdsAddress(peer.messagingSocketPath), name: peer.name ?? peer.kind, cwd: peer.cwd, @@ -114,6 +122,10 @@ Use this tool to discover messaging targets before sending cross-session message }) } + for (const peer of await bridgePeers.listBridgePeers()) { + addPeer(peer) + } + return { data: { peers }, } diff --git a/packages/builtin-tools/src/tools/SendMessageTool/SendMessageTool.ts b/packages/builtin-tools/src/tools/SendMessageTool/SendMessageTool.ts index 3548544fca..e4868bc531 100644 --- a/packages/builtin-tools/src/tools/SendMessageTool/SendMessageTool.ts +++ b/packages/builtin-tools/src/tools/SendMessageTool/SendMessageTool.ts @@ -672,17 +672,15 @@ export const SendMessageTool: Tool = errorCode: 9, } } - if (feature('UDS_INBOX')) { - if ( - addr.scheme === 'uds' && - (hasInlineUdsToken(input.to) || wasInlineUdsTokenRejected(input)) - ) { - return { - result: false, - message: - 'uds addresses must not include inline auth tokens; use the ListPeers address', - errorCode: 9, - } + if ( + addr.scheme === 'uds' && + (hasInlineUdsToken(input.to) || wasInlineUdsTokenRejected(input)) + ) { + return { + result: false, + message: + 'uds addresses must not include inline auth tokens; use the ListPeers address', + errorCode: 9, } } if (input.to.includes('@')) { @@ -808,6 +806,22 @@ export const SendMessageTool: Tool = }, async call(input, context, canUseTool, assistantMessage) { + if (typeof input.message === 'string') { + const addr = parseAddress(input.to) + if ( + addr.scheme === 'uds' && + (hasInlineUdsToken(input.to) || wasInlineUdsTokenRejected(input)) + ) { + return { + data: { + success: false, + message: + 'uds addresses must not include inline auth tokens; use the ListPeers address', + }, + } + } + } + if (feature('UDS_INBOX') && typeof input.message === 'string') { const addr = parseAddress(input.to) if (addr.scheme === 'bridge') { @@ -827,10 +841,10 @@ export const SendMessageTool: Tool = const { postInterClaudeMessage } = require('src/bridge/peerSessions.js') as typeof import('src/bridge/peerSessions.js') /* eslint-enable @typescript-eslint/no-require-imports */ - const result = await postInterClaudeMessage( + const result = (await postInterClaudeMessage( addr.target, input.message, - ) as { ok: boolean; error?: string } + )) as { ok: boolean; error?: string } const preview = input.summary || truncate(input.message, 50) return { data: { diff --git a/packages/builtin-tools/src/tools/SendMessageTool/__tests__/udsRecipientSanitization.test.ts b/packages/builtin-tools/src/tools/SendMessageTool/__tests__/udsRecipientSanitization.test.ts index 20124b6c38..a0ab2af0dc 100644 --- a/packages/builtin-tools/src/tools/SendMessageTool/__tests__/udsRecipientSanitization.test.ts +++ b/packages/builtin-tools/src/tools/SendMessageTool/__tests__/udsRecipientSanitization.test.ts @@ -1,8 +1,4 @@ -import { describe, expect, mock, test } from 'bun:test' - -mock.module('bun:bundle', () => ({ - feature: (name: string) => name === 'UDS_INBOX', -})) +import { describe, expect, test } from 'bun:test' describe('SendMessageTool UDS recipient handling', () => { test('redacts inline UDS tokens before classifier and observable paths', async () => { @@ -25,6 +21,62 @@ describe('SendMessageTool UDS recipient handling', () => { ).toBe('to uds:/tmp/peer.sock: hello') }) + test('keeps redacted UDS token rejection through observable backfill', async () => { + const { SendMessageTool } = await import('../SendMessageTool.js') + const observableInput = { + to: 'uds:/tmp/peer.sock#token=secret-token', + message: { + type: 'plan_approval_response', + request_id: 'req-1', + approve: false, + reason: 'needs tests', + }, + } as Record + + SendMessageTool.backfillObservableInput!(observableInput) + + expect(observableInput.to).toBe('uds:/tmp/peer.sock') + expect(observableInput.recipient).toBe('uds:/tmp/peer.sock') + expect(observableInput.type).toBe('plan_approval_response') + expect(observableInput.request_id).toBe('req-1') + expect(observableInput.approve).toBe(false) + expect(observableInput.content).toBe('needs tests') + expect(JSON.stringify(observableInput)).not.toContain('secret-token') + + const result = await SendMessageTool.validateInput!( + observableInput as never, + {} as never, + ) + + expect(result.result).toBe(false) + if (result.result !== false) { + throw new Error('expected validation to reject redacted inline UDS token') + } + expect(result.message).toContain('inline auth tokens') + }) + + test('redacts UDS tokens in structured classifier text', async () => { + const { SendMessageTool } = await import('../SendMessageTool.js') + const to = 'uds:/tmp/peer.sock#token=secret-token' + + expect( + SendMessageTool.toAutoClassifierInput({ + to, + message: { type: 'shutdown_request' }, + }), + ).toBe('shutdown_request to uds:/tmp/peer.sock') + expect( + SendMessageTool.toAutoClassifierInput({ + to, + message: { + type: 'plan_approval_response', + request_id: 'req-1', + approve: true, + }, + }), + ).toBe('plan_approval approve to uds:/tmp/peer.sock') + }) + test('rejects inline UDS tokens during validation', async () => { const { SendMessageTool } = await import('../SendMessageTool.js') const result = await SendMessageTool.validateInput!( @@ -36,6 +88,26 @@ describe('SendMessageTool UDS recipient handling', () => { ) expect(result.result).toBe(false) + if (result.result !== false) { + throw new Error('expected validation to reject inline UDS token') + } + expect(result.message).toContain('inline auth tokens') + expect(JSON.stringify(result)).not.toContain('secret-token') + }) + + test('rejects inline UDS tokens during execution without leaking them', async () => { + const { SendMessageTool } = await import('../SendMessageTool.js') + const result = await SendMessageTool.call( + { + to: 'uds:/tmp/peer.sock#token=secret-token', + message: 'hello', + }, + {} as never, + undefined as never, + undefined as never, + ) + + expect(result.data.success).toBe(false) expect(JSON.stringify(result)).not.toContain('secret-token') }) }) diff --git a/src/bridge/peerSessions.ts b/src/bridge/peerSessions.ts index c194c9b624..716c879deb 100644 --- a/src/bridge/peerSessions.ts +++ b/src/bridge/peerSessions.ts @@ -6,6 +6,38 @@ import { getBridgeAccessToken } from './bridgeConfig.js' import { getReplBridgeHandle } from './replBridgeHandle.js' import { toCompatSessionId } from './sessionIdCompat.js' +export type BridgePeerSession = { + address: string + name?: string + cwd?: string + pid?: number +} + +/** + * List locally registered sessions that have published a Remote Control + * session ID. The PID registry is the local source of truth for bridge peers + * already known to this machine; SendMessage can use these bridge: + * addresses when the current process has an active bridge handle. + */ +export async function listBridgePeers(): Promise { + const { listAllLiveSessions } = await import('../utils/udsClient.js') + const sessions = await listAllLiveSessions() + const peers: BridgePeerSession[] = [] + + for (const session of sessions) { + if (session.pid === process.pid || !session.bridgeSessionId) continue + const compatId = toCompatSessionId(session.bridgeSessionId) + peers.push({ + address: `bridge:${compatId}`, + name: session.name ?? session.kind, + cwd: session.cwd, + pid: session.pid, + }) + } + + return peers +} + /** * Send a plain-text message to another Claude session via the bridge API. * diff --git a/src/cli/print.ts b/src/cli/print.ts index 7a291fb504..c4e8c45697 100644 --- a/src/cli/print.ts +++ b/src/cli/print.ts @@ -2773,7 +2773,7 @@ function runHeadlessStreaming( const value = typeof entry.message.data === 'string' ? entry.message.data - : jsonStringify(entry.message) + : jsonStringify(entry.message.data) enqueue({ mode: 'prompt', value, diff --git a/src/services/AgentSummary/__tests__/agentSummary.test.ts b/src/services/AgentSummary/__tests__/agentSummary.test.ts new file mode 100644 index 0000000000..0ab0700800 --- /dev/null +++ b/src/services/AgentSummary/__tests__/agentSummary.test.ts @@ -0,0 +1,133 @@ +import { + afterAll, + afterEach, + beforeEach, + describe, + expect, + mock, + test, +} from 'bun:test' +import { debugMock } from '../../../../tests/mocks/debug' +import { logMock } from '../../../../tests/mocks/log' +import { asAgentId } from '../../../types/ids.js' +import type { CacheSafeParams } from '../../../utils/forkedAgent.js' + +const transcriptMessages = [ + { type: 'user', message: { content: 'start' }, uuid: 'u1' }, + { + type: 'assistant', + message: { content: [{ type: 'text', text: 'working' }] }, + uuid: 'a1', + }, + { type: 'user', message: { content: 'continue' }, uuid: 'u2' }, +] + +let poorModeActive = false +let forkCalls = 0 +let updateCalls: Array<{ taskId: string; summary: string }> = [] +let transcript = { messages: transcriptMessages } +const sessionStorageSnapshot = { + ...(require('../../../utils/sessionStorage.ts') as Record), +} + +mock.module('src/commands/poor/poorMode.js', () => ({ + isPoorModeActive: () => poorModeActive, +})) + +mock.module('src/tasks/LocalAgentTask/LocalAgentTask.js', () => ({ + updateAgentSummary: (taskId: string, summary: string) => { + updateCalls.push({ taskId, summary }) + }, +})) + +mock.module( + '@claude-code-best/builtin-tools/tools/AgentTool/runAgent.js', + () => ({ + filterIncompleteToolCalls: (messages: T) => messages, + }), +) + +mock.module('src/utils/debug.js', debugMock) +mock.module('src/utils/log.js', logMock) + +mock.module('src/utils/forkedAgent.js', () => ({ + runForkedAgent: async () => { + forkCalls += 1 + return { + messages: [ + { + type: 'assistant', + message: { + content: [{ type: 'text', text: 'Reading udsClient.ts' }], + }, + }, + ], + } + }, +})) + +mock.module('src/utils/sessionStorage.js', () => ({ + ...sessionStorageSnapshot, + getAgentTranscript: async () => transcript, +})) + +afterAll(() => { + mock.module('src/utils/sessionStorage.js', () => + require('../../../utils/sessionStorage.ts'), + ) +}) + +describe('startAgentSummarization', () => { + const realSetTimeout = globalThis.setTimeout + const realClearTimeout = globalThis.clearTimeout + let scheduled: + | ((...args: Parameters void)>) => void) + | undefined + + beforeEach(() => { + poorModeActive = false + forkCalls = 0 + updateCalls = [] + transcript = { messages: transcriptMessages } + scheduled = undefined + globalThis.setTimeout = ((callback: TimerHandler) => { + scheduled = callback as (...args: unknown[]) => void + return 1 as unknown as ReturnType + }) as unknown as typeof setTimeout + globalThis.clearTimeout = (() => undefined) as typeof clearTimeout + }) + + afterEach(() => { + globalThis.setTimeout = realSetTimeout + globalThis.clearTimeout = realClearTimeout + }) + + test('summarizes bounded transcript once and skips unchanged fingerprints', async () => { + const { startAgentSummarization } = await import('../agentSummary.js') + + const handle = startAgentSummarization( + 'task-1', + asAgentId('a0000000000000000'), + { + forkContextMessages: [{ type: 'user', message: { content: 'old' } }], + model: 'claude-test', + } as unknown as CacheSafeParams, + () => undefined, + ) + + expect(typeof scheduled).toBe('function') + await scheduled!() + + expect(forkCalls).toBe(1) + expect(updateCalls).toEqual([ + { taskId: 'task-1', summary: 'Reading udsClient.ts' }, + ]) + + await scheduled!() + + expect(forkCalls).toBe(1) + expect(updateCalls).toHaveLength(1) + + handle.stop() + }) +}) diff --git a/src/services/AgentSummary/__tests__/summaryContext.test.ts b/src/services/AgentSummary/__tests__/summaryContext.test.ts index 3ffa559645..1c701d14b4 100644 --- a/src/services/AgentSummary/__tests__/summaryContext.test.ts +++ b/src/services/AgentSummary/__tests__/summaryContext.test.ts @@ -2,6 +2,7 @@ import { describe, expect, test } from 'bun:test' import type { Message } from '../../../types/message.js' import { getSummaryContextFingerprint, + MAX_SUMMARY_CONTEXT_CHARS, selectSummaryContextMessages, } from '../summaryContext.js' @@ -101,6 +102,10 @@ describe('selectSummaryContextMessages', () => { }) describe('getSummaryContextFingerprint', () => { + test('returns null for an empty transcript', () => { + expect(getSummaryContextFingerprint([])).toBeNull() + }) + test('changes when the transcript grows', () => { const messages = [ makeMessage('user', 'u1', 'first prompt'), @@ -129,4 +134,16 @@ describe('getSummaryContextFingerprint', () => { expect(first).not.toBe(second) }) + + test('includes a truncation marker for oversized primitive values', () => { + const prefix = 'x'.repeat(MAX_SUMMARY_CONTEXT_CHARS + 100) + const first = getSummaryContextFingerprint([ + makeMessage('assistant', 'a1', `${prefix}a`), + ]) + const second = getSummaryContextFingerprint([ + makeMessage('assistant', 'a1', `${prefix}b`), + ]) + + expect(first).not.toBe(second) + }) }) diff --git a/src/services/AgentSummary/summaryContext.ts b/src/services/AgentSummary/summaryContext.ts index 4d9f6a6ce5..894a21e360 100644 --- a/src/services/AgentSummary/summaryContext.ts +++ b/src/services/AgentSummary/summaryContext.ts @@ -55,10 +55,15 @@ function updateFingerprintHash( if (limit.remaining <= 0) return if (value === null || typeof value !== 'object') { const text = String(value) + const consumed = Math.min(text.length, limit.remaining) + if (consumed <= 0) return hash.update(typeof value) hash.update(':') - hash.update(text.slice(0, limit.remaining)) - limit.remaining -= text.length + hash.update(text.slice(0, consumed)) + if (consumed < text.length) { + hash.update(`#truncated:${text.length}:${text.slice(-64)}`) + } + limit.remaining -= consumed return } if (seen.has(value)) { diff --git a/src/utils/__tests__/teammateMailbox.test.ts b/src/utils/__tests__/teammateMailbox.test.ts index 577c4331f4..7f479ed361 100644 --- a/src/utils/__tests__/teammateMailbox.test.ts +++ b/src/utils/__tests__/teammateMailbox.test.ts @@ -3,8 +3,10 @@ import { mkdir, readFile, rm, writeFile } from 'node:fs/promises' import { mkdtempSync } from 'node:fs' import { tmpdir } from 'node:os' import { dirname, join } from 'node:path' +import type { Message } from '../../types/message.js' import { compactMailboxMessages, + getLastPeerDmSummary, getInboxPath, markMessageAsReadByIndex, markMessageAsReadByIdentity, @@ -119,6 +121,23 @@ describe('compactMailboxMessages', () => { ]) }) + test('does not prioritize malformed JSON-like unread messages as protocol', () => { + const compacted = compactMailboxMessages( + [ + message('{not-json', false), + message('regular-1', false), + message('regular-2', false), + ], + { + maxMessages: 1, + maxReadMessages: 0, + maxUnreadProtocolMessages: 10, + }, + ) + + expect(compacted.map(m => m.text)).toEqual(['regular-2']) + }) + test('caps unread protocol messages with an independent bound', () => { const compacted = compactMailboxMessages( Array.from( @@ -308,3 +327,54 @@ describe('teammate mailbox retention', () => { await expect(readMailbox('worker', 'alpha')).rejects.toThrow() }) }) + +describe('getLastPeerDmSummary', () => { + test('extracts the final peer direct-message summary from assistant tool use', () => { + const messages = [ + { type: 'user', message: { content: 'wake up' } }, + { + type: 'assistant', + message: { + content: [ + { + type: 'tool_use', + name: 'SendMessage', + input: { + to: 'worker-1', + message: 'please check the UDS bounds', + summary: 'Checking UDS bounds', + }, + }, + ], + }, + }, + ] as unknown as Message[] + + expect(getLastPeerDmSummary(messages)).toBe( + '[to worker-1] Checking UDS bounds', + ) + }) + + test('stops peer direct-message summary search at the wake-up boundary', () => { + const messages = [ + { + type: 'assistant', + message: { + content: [ + { + type: 'tool_use', + name: 'SendMessage', + input: { + to: 'worker-1', + message: 'old message', + }, + }, + ], + }, + }, + { type: 'user', message: { content: 'new prompt' } }, + ] as unknown as Message[] + + expect(getLastPeerDmSummary(messages)).toBeUndefined() + }) +}) diff --git a/src/utils/__tests__/udsMessaging.test.ts b/src/utils/__tests__/udsMessaging.test.ts index 52cb57abc6..ef943cb76b 100644 --- a/src/utils/__tests__/udsMessaging.test.ts +++ b/src/utils/__tests__/udsMessaging.test.ts @@ -1,5 +1,13 @@ -import { afterEach, describe, expect, test } from 'bun:test' -import { chmod, mkdir, rm, stat, symlink, unlink } from 'node:fs/promises' +import { afterEach, beforeEach, describe, expect, test } from 'bun:test' +import { + chmod, + mkdir, + mkdtemp, + rm, + stat, + symlink, + unlink, +} from 'node:fs/promises' import { createConnection, createServer } from 'node:net' import { dirname, join } from 'node:path' import { tmpdir } from 'node:os' @@ -15,6 +23,9 @@ import { stopUdsMessaging, } from '../udsMessaging.js' +let previousConfigDir: string | undefined +let tempConfigDir = '' + function socketPath(label: string): string { const suffix = `${process.pid}-${Date.now()}-${Math.random().toString(16).slice(2)}-${label}` if (process.platform === 'win32') { @@ -52,10 +63,25 @@ async function waitForEnqueues( setOnEnqueue(null) } +beforeEach(async () => { + previousConfigDir = process.env.CLAUDE_CONFIG_DIR + tempConfigDir = await mkdtemp(join(tmpdir(), 'uds-messaging-home-')) + process.env.CLAUDE_CONFIG_DIR = tempConfigDir +}) + afterEach(async () => { setOnEnqueue(null) drainInbox() await stopUdsMessaging() + if (previousConfigDir === undefined) { + delete process.env.CLAUDE_CONFIG_DIR + } else { + process.env.CLAUDE_CONFIG_DIR = previousConfigDir + } + if (tempConfigDir) { + await rm(tempConfigDir, { recursive: true, force: true }) + tempConfigDir = '' + } }) async function closeServer(server: ReturnType): Promise { @@ -133,6 +159,57 @@ describe('UDS inbox retention', () => { expect(drainInbox()).toEqual([]) }) + test('udsClient helpers authenticate through the capability file', async () => { + const path = socketPath('uds-client') + await startUdsMessaging(path, { isExplicit: true }) + const { isPeerAlive, sendToUdsSocket } = await import('../udsClient.js') + + expect(await isPeerAlive(path)).toBe(true) + await waitForEnqueues(1, async () => { + await sendToUdsSocket(path, 'hello from client') + }) + + const drained = drainInbox() + expect(drained).toHaveLength(1) + expect(drained[0]?.message.data).toBe('hello from client') + expect(drained[0]?.message.meta).toBeUndefined() + }) + + test('udsClient peer probe fails closed on oversized pong frames', async () => { + const path = socketPath('uds-client-oversized-pong') + if (process.platform !== 'win32') { + await mkdir(dirname(path), { recursive: true }) + } + const receiver = createServer(socket => { + socket.on('data', () => { + socket.write('x'.repeat(MAX_UDS_FRAME_BYTES + 1)) + }) + }) + await new Promise((resolve, reject) => { + receiver.on('error', reject) + receiver.listen(path, () => resolve()) + }) + + try { + const { isPeerAlive } = await import('../udsClient.js') + expect(await isPeerAlive(path)).toBe(false) + } finally { + await closeServer(receiver) + if (process.platform !== 'win32') { + await unlink(path).catch(() => undefined) + } + } + }) + + test('udsClient send fails closed when no capability token exists', async () => { + const path = socketPath('uds-client-no-token') + const { sendToUdsSocket } = await import('../udsClient.js') + + await expect(sendToUdsSocket(path, 'hello')).rejects.toThrow( + 'No auth token found', + ) + }) + test('drained entries never expose the UDS auth token', async () => { const path = socketPath('strip-token') await startUdsMessaging(path, { isExplicit: true }) @@ -301,5 +378,24 @@ describe('UDS inbox retention', () => { await rm(tempHome, { recursive: true, force: true }) } }) + + test('fails closed when an explicit socket parent is not private', async () => { + const parent = join( + tmpdir(), + `uds-socket-parent-${process.pid}-${Date.now()}-${Math.random().toString(16).slice(2)}`, + ) + await mkdir(parent, { recursive: true, mode: 0o755 }) + await chmod(parent, 0o755) + + try { + await expect( + startUdsMessaging(join(parent, 'messaging.sock'), { + isExplicit: true, + }), + ).rejects.toThrow('socket parent permissions are too broad') + } finally { + await rm(parent, { recursive: true, force: true }) + } + }) } }) diff --git a/src/utils/ndjsonFramer.ts b/src/utils/ndjsonFramer.ts index ecaa04dc8c..7832e93036 100644 --- a/src/utils/ndjsonFramer.ts +++ b/src/utils/ndjsonFramer.ts @@ -27,6 +27,7 @@ export function attachNdjsonFramer( options: NdjsonFramerOptions = {}, ): void { let buffer = '' + let bufferBytes = 0 const maxFrameBytes = options.maxFrameBytes ?? Number.POSITIVE_INFINITY const rejectOversizedFrame = (bytes: number): void => { @@ -37,41 +38,48 @@ export function attachNdjsonFramer( socket.destroy(error) } - socket.on('data', (chunk: Buffer) => { - if ( - Number.isFinite(maxFrameBytes) && - !chunk.includes(0x0a) && - Buffer.byteLength(buffer, 'utf8') + chunk.byteLength > maxFrameBytes - ) { - rejectOversizedFrame(Buffer.byteLength(buffer, 'utf8') + chunk.byteLength) - return + const emitLine = (line: string): void => { + if (!line.trim()) return + try { + onMessage(parse(line)) + } catch { + // Malformed JSON — skip } + } - buffer += chunk.toString() - const lines = buffer.split('\n') - buffer = lines.pop() ?? '' + socket.on('data', (chunk: Buffer) => { + let start = 0 + for (let index = 0; index < chunk.length; index++) { + if (chunk[index] !== 0x0a) continue - for (const line of lines) { - if (!line.trim()) continue + const segmentBytes = index - start if ( Number.isFinite(maxFrameBytes) && - Buffer.byteLength(line, 'utf8') > maxFrameBytes + bufferBytes + segmentBytes > maxFrameBytes ) { - rejectOversizedFrame(Buffer.byteLength(line, 'utf8')) + rejectOversizedFrame(bufferBytes + segmentBytes) return } - try { - onMessage(parse(line)) - } catch { - // Malformed JSON — skip - } + + buffer += chunk.subarray(start, index).toString('utf8') + emitLine(buffer) + buffer = '' + bufferBytes = 0 + start = index + 1 } + const tailBytes = chunk.length - start if ( Number.isFinite(maxFrameBytes) && - Buffer.byteLength(buffer, 'utf8') > maxFrameBytes + bufferBytes + tailBytes > maxFrameBytes ) { - rejectOversizedFrame(Buffer.byteLength(buffer, 'utf8')) + rejectOversizedFrame(bufferBytes + tailBytes) + return + } + + if (tailBytes > 0) { + buffer += chunk.subarray(start).toString('utf8') + bufferBytes += tailBytes } }) } diff --git a/src/utils/teammateMailbox.ts b/src/utils/teammateMailbox.ts index 6c18fd7213..ad9b22f931 100644 --- a/src/utils/teammateMailbox.ts +++ b/src/utils/teammateMailbox.ts @@ -8,7 +8,7 @@ */ import { randomBytes } from 'crypto' -import { mkdir, readFile, rename, stat, writeFile } from 'fs/promises' +import { mkdir, readFile, rename, stat, unlink, writeFile } from 'fs/promises' import { join } from 'path' import { z } from 'zod/v4' import { TEAMMATE_MESSAGE_TAG } from '../constants/xml.js' @@ -77,7 +77,7 @@ function shouldRetainUnreadAsProtocolMessage( 'type' in (parsed as Record), ) } catch { - return true + return false } } @@ -160,8 +160,13 @@ async function writeMailboxAtomic( ) } const tempPath = `${inboxPath}.${process.pid}.${randomBytes(8).toString('hex')}.tmp` - await writeFile(tempPath, content, 'utf-8') - await rename(tempPath, inboxPath) + try { + await writeFile(tempPath, content, 'utf-8') + await rename(tempPath, inboxPath) + } catch (error) { + await unlink(tempPath).catch(() => undefined) + throw error + } } export function compactMailboxMessages( diff --git a/src/utils/udsClient.ts b/src/utils/udsClient.ts index f08bee696b..e33ef3fdb7 100644 --- a/src/utils/udsClient.ts +++ b/src/utils/udsClient.ts @@ -17,6 +17,7 @@ import { isProcessRunning } from './genericProcessUtils.js' import { jsonParse, jsonStringify } from './slowOperations.js' import type { SessionKind } from './concurrentSessions.js' import { MAX_UDS_FRAME_BYTES, type UdsMessage } from './udsMessaging.js' +import { attachUdsResponseReader, getChunkBytes } from './udsResponseReader.js' // --------------------------------------------------------------------------- // Types @@ -43,12 +44,6 @@ function getSessionsDir(): string { return join(getClaudeConfigHomeDir(), 'sessions') } -function getChunkBytes(chunk: string | Buffer): number { - return typeof chunk === 'string' - ? Buffer.byteLength(chunk, 'utf8') - : chunk.byteLength -} - // --------------------------------------------------------------------------- // Discovery // --------------------------------------------------------------------------- @@ -218,56 +213,33 @@ export async function sendToUdsSocket( udsMsg.from = getUdsMessagingSocketPath() return new Promise((resolve, reject) => { - let buffer = '' let settled = false + let conn: ReturnType const finish = (error?: Error): void => { if (settled) return settled = true - conn.end() - if (error) reject(error) - else resolve() + if (error) { + conn.destroy(error) + reject(error) + } else { + conn.end() + resolve() + } } - const conn = createConnection(target.socketPath, () => { + + conn = createConnection(target.socketPath, () => { udsMsg.meta = { ...udsMsg.meta, authToken } conn.write(jsonStringify(udsMsg) + '\n', err => { if (err) finish(err) }) }) - conn.on('data', chunk => { - if ( - Buffer.byteLength(buffer, 'utf8') + getChunkBytes(chunk) > - MAX_UDS_FRAME_BYTES - ) { - finish(new Error('UDS response frame exceeded size limit')) - return - } - buffer += chunk.toString() - const lines = buffer.split('\n') - buffer = lines.pop() ?? '' - for (const line of lines) { - if (!line.trim()) continue - let response: UdsMessage - try { - response = jsonParse(line) as UdsMessage - } catch { - continue - } - if (response.type === 'response') { - finish() - return - } - if (response.type === 'error') { - finish(new Error(response.data ?? 'UDS receiver rejected message')) - return - } - } - }) - conn.on('error', err => { - finish( + attachUdsResponseReader(conn, { + maxFrameBytes: MAX_UDS_FRAME_BYTES, + onSettled: finish, + formatSocketError: err => new Error( `Failed to connect to peer at ${target.socketPath}: ${errorMessage(err)}`, ), - ) }) conn.setTimeout(5000, () => { finish(new Error('Connection timed out')) diff --git a/src/utils/udsMessaging.ts b/src/utils/udsMessaging.ts index 7efa7fbf65..94b73dcd6a 100644 --- a/src/utils/udsMessaging.ts +++ b/src/utils/udsMessaging.ts @@ -8,7 +8,7 @@ * but can be overridden via --messaging-socket-path. */ -import { createHash, randomBytes } from 'crypto' +import { createHash, randomBytes, timingSafeEqual } from 'crypto' import { createServer, type Server, type Socket } from 'net' import { chmod, @@ -26,6 +26,7 @@ import { logForDebugging } from './debug.js' import { errorMessage } from './errors.js' import { getClaudeConfigHomeDir } from './envUtils.js' import { attachNdjsonFramer } from './ndjsonFramer.js' +import { attachUdsResponseReader } from './udsResponseReader.js' import { logError } from './log.js' import { jsonParse, jsonStringify } from './slowOperations.js' @@ -160,26 +161,36 @@ async function assertPrivateCapabilityDir(dir: string): Promise { stat = await lstat(dir) } + assertPrivateDirectory(stat, dir, 'capability directory') + await chmod(dir, 0o700) +} + +function assertPrivateDirectory( + stat: Awaited>, + dir: string, + label: string, +): void { if (!stat.isDirectory() || stat.isSymbolicLink()) { throw new Error( - `[udsMessaging] capability directory is not a private directory: ${dir}`, + `[udsMessaging] ${label} is not a private directory: ${dir}`, ) } if (process.platform !== 'win32') { - const broadMode = stat.mode & 0o077 + const broadMode = Number(stat.mode) & 0o077 if (broadMode !== 0) { throw new Error( - `[udsMessaging] capability directory permissions are too broad: ${dir}`, + `[udsMessaging] ${label} permissions are too broad: ${dir}`, ) } - if (typeof process.getuid === 'function' && stat.uid !== process.getuid()) { + if ( + typeof process.getuid === 'function' && + Number(stat.uid) !== process.getuid() + ) { throw new Error( - `[udsMessaging] capability directory owner does not match current user: ${dir}`, + `[udsMessaging] ${label} owner does not match current user: ${dir}`, ) } } - - await chmod(dir, 0o700) } async function writePrivateFileExclusive( @@ -204,6 +215,7 @@ async function ensureSocketParent(path: string): Promise { `[udsMessaging] socket parent is not a directory: ${dir}`, ) } + assertPrivateDirectory(stat, dir, 'socket parent') return } catch (error) { if (!isNotFound(error)) throw error @@ -314,7 +326,12 @@ function getMessageAuthToken(message: UdsMessage): string | undefined { } function isAuthorizedMessage(message: UdsMessage): boolean { - return getMessageAuthToken(message) === authToken + const provided = getMessageAuthToken(message) + if (!provided || !authToken) return false + const providedBuffer = Buffer.from(provided, 'utf8') + const expectedBuffer = Buffer.from(authToken, 'utf8') + if (providedBuffer.length !== expectedBuffer.length) return false + return timingSafeEqual(providedBuffer, expectedBuffer) } function writeSocketMessage(socket: Socket, message: UdsMessage): void { @@ -554,20 +571,6 @@ export async function stopUdsMessaging(): Promise { } } -function parseResponseLine(line: string): UdsMessage | null { - try { - return jsonParse(line) as UdsMessage - } catch { - return null - } -} - -function getChunkBytes(chunk: string | Buffer): number { - return typeof chunk === 'string' - ? Buffer.byteLength(chunk, 'utf8') - : chunk.byteLength -} - /** * Send a UDS message to a specific socket path (outbound — used when this * session wants to push a message to a peer's server). @@ -592,46 +595,30 @@ export async function sendUdsMessage( ) return new Promise((resolve, reject) => { - let buffer = '' let settled = false + let conn: ReturnType const finish = (error?: Error): void => { if (settled) return settled = true - conn.end() - if (error) reject(error) - else resolve() + if (error) { + conn.destroy(error) + reject(error) + } else { + conn.end() + resolve() + } } - const conn = createConnection(targetSocketPath, () => { + + conn = createConnection(targetSocketPath, () => { conn.write(jsonStringify(outbound) + '\n', err => { if (err) finish(err) }) }) - conn.on('data', chunk => { - if ( - Buffer.byteLength(buffer, 'utf8') + getChunkBytes(chunk) > - MAX_UDS_FRAME_BYTES - ) { - finish(new Error('UDS response frame exceeded size limit')) - return - } - buffer += chunk.toString() - const lines = buffer.split('\n') - buffer = lines.pop() ?? '' - for (const line of lines) { - if (!line.trim()) continue - const response = parseResponseLine(line) - if (!response) continue - if (response.type === 'response' || response.type === 'pong') { - finish() - return - } - if (response.type === 'error') { - finish(new Error(response.data ?? 'UDS receiver rejected message')) - return - } - } + attachUdsResponseReader(conn, { + maxFrameBytes: MAX_UDS_FRAME_BYTES, + acceptPong: true, + onSettled: finish, }) - conn.on('error', err => finish(err)) // Timeout so we don't hang on unreachable sockets conn.setTimeout(5000, () => { finish(new Error('Connection timed out')) diff --git a/src/utils/udsResponseReader.ts b/src/utils/udsResponseReader.ts new file mode 100644 index 0000000000..bb8d21f40f --- /dev/null +++ b/src/utils/udsResponseReader.ts @@ -0,0 +1,81 @@ +import type { Socket } from 'net' +import { errorMessage } from './errors.js' +import { jsonParse } from './slowOperations.js' +import type { UdsMessage } from './udsMessaging.js' + +type UdsResponseReaderOptions = { + maxFrameBytes: number + acceptPong?: boolean + onSettled: (error?: Error) => void + formatSocketError?: (error: unknown) => Error +} + +export function getChunkBytes(chunk: string | Buffer): number { + return typeof chunk === 'string' + ? Buffer.byteLength(chunk, 'utf8') + : chunk.byteLength +} + +function parseResponseLine(line: string): UdsMessage | null { + try { + return jsonParse(line) as UdsMessage + } catch { + return null + } +} + +export function attachUdsResponseReader( + socket: Socket, + options: UdsResponseReaderOptions, +): void { + let buffer = '' + let settled = false + + const finish = (error?: Error): void => { + if (settled) return + settled = true + if (error) { + socket.destroy(error) + } else { + socket.end() + } + options.onSettled(error) + } + + socket.on('data', chunk => { + if ( + Buffer.byteLength(buffer, 'utf8') + getChunkBytes(chunk) > + options.maxFrameBytes + ) { + finish(new Error('UDS response frame exceeded size limit')) + return + } + + buffer += chunk.toString() + const lines = buffer.split('\n') + buffer = lines.pop() ?? '' + for (const line of lines) { + if (!line.trim()) continue + const response = parseResponseLine(line) + if (!response) continue + if ( + response.type === 'response' || + (options.acceptPong === true && response.type === 'pong') + ) { + finish() + return + } + if (response.type === 'error') { + finish(new Error(response.data ?? 'UDS receiver rejected message')) + return + } + } + }) + + socket.on('error', error => { + finish( + options.formatSocketError?.(error) ?? + (error instanceof Error ? error : new Error(errorMessage(error))), + ) + }) +} From 4e2ff560204fc18299d30a96978c8cf5c0269dbb Mon Sep 17 00:00:00 2001 From: unraid Date: Mon, 27 Apr 2026 12:50:08 +0800 Subject: [PATCH 3/3] fix: prevent agent communication bounds from hiding CI regressions Tighten the UDS auth, framing, and response-reader boundaries while keeping the AgentSummary lifecycle covered so Codecov and CI fail on real regressions instead of missing coverage. The poorMode settings mock mirrors unrelated real settings defaults to avoid Bun mock retention changing later permission tests. Constraint: PR #369 must fix Codecov/CI precisely without warning suppression, fallback masking, or mock pollution Rejected: Delete AgentSummary lifecycle coverage | would hide Codecov loss and stale-summary behavior Rejected: Store inline UDS rejection in a hidden input sentinel | cloned observable inputs can drop it and bypass rejection Rejected: Ignore malformed UDS frames until timeout | leaves client slots and SendMessage calls open to exhaustion Confidence: high Scope-risk: moderate Directive: Keep empty #token= markers rejected; do not require a non-empty token value in hasInlineUdsToken Tested: bun test packages/builtin-tools/src/tools/SendMessageTool/__tests__/udsRecipientSanitization.test.ts src/utils/__tests__/udsMessaging.test.ts src/utils/__tests__/udsResponseReader.test.ts src/utils/__tests__/ndjsonFramer.test.ts Tested: bunx tsc --noEmit --pretty false Tested: bun run lint Tested: bun test --coverage --coverage-reporter lcov --coverage-dir coverage Tested: bun run test:all Tested: bun audit Tested: bun run build Tested: bun run build:vite Not-tested: GitHub-hosted Codecov upload until pushed PR checks rerun --- .../filterIncompleteToolCalls.test.ts | 180 ++++++++++++++ .../AgentTool/filterIncompleteToolCalls.ts | 110 +++++++++ .../src/tools/AgentTool/runAgent.ts | 47 +--- .../tools/SendMessageTool/SendMessageTool.ts | 46 ++-- .../udsRecipientSanitization.test.ts | 80 ++++++- src/commands/poor/__tests__/poorMode.test.ts | 47 +++- .../__tests__/agentSummary.test.ts | 201 ++++++++-------- .../__tests__/summaryContext.test.ts | 112 +++++++++ .../__tests__/summaryPrompt.test.ts | 34 +++ src/services/AgentSummary/agentSummary.ts | 102 ++++---- src/services/AgentSummary/summaryContext.ts | 41 +++- src/services/AgentSummary/summaryPrompt.ts | 32 +++ src/utils/__tests__/ndjsonFramer.test.ts | 62 +++++ src/utils/__tests__/teammateMailbox.test.ts | 116 +++++++-- src/utils/__tests__/udsMessaging.test.ts | 221 +++++++++++++++++- src/utils/__tests__/udsResponseReader.test.ts | 171 ++++++++++++++ src/utils/messages/systemInit.ts | 4 +- src/utils/ndjsonFramer.ts | 23 +- src/utils/swarm/inProcessRunner.ts | 14 +- src/utils/udsMessaging.ts | 157 ++++++++++--- src/utils/udsResponseReader.ts | 77 ++++-- 21 files changed, 1553 insertions(+), 324 deletions(-) create mode 100644 packages/builtin-tools/src/tools/AgentTool/__tests__/filterIncompleteToolCalls.test.ts create mode 100644 packages/builtin-tools/src/tools/AgentTool/filterIncompleteToolCalls.ts create mode 100644 src/services/AgentSummary/__tests__/summaryPrompt.test.ts create mode 100644 src/services/AgentSummary/summaryPrompt.ts create mode 100644 src/utils/__tests__/udsResponseReader.test.ts diff --git a/packages/builtin-tools/src/tools/AgentTool/__tests__/filterIncompleteToolCalls.test.ts b/packages/builtin-tools/src/tools/AgentTool/__tests__/filterIncompleteToolCalls.test.ts new file mode 100644 index 0000000000..5429c6ce81 --- /dev/null +++ b/packages/builtin-tools/src/tools/AgentTool/__tests__/filterIncompleteToolCalls.test.ts @@ -0,0 +1,180 @@ +import { describe, expect, test } from 'bun:test' +import type { Message } from 'src/types/message.js' +import { filterIncompleteToolCalls } from '../filterIncompleteToolCalls.js' + +describe('filterIncompleteToolCalls', () => { + test('drops assistant tool uses that do not have matching results', () => { + const messages = [ + { + type: 'assistant', + uuid: 'a1', + message: { + role: 'assistant', + content: [{ type: 'tool_use', id: 'missing', name: 'Read' }], + }, + }, + { + type: 'user', + uuid: 'u1', + message: { role: 'user', content: 'continue' }, + }, + ] as unknown as Message[] + + expect( + filterIncompleteToolCalls(messages).map(message => String(message.uuid)), + ).toEqual(['u1']) + }) + + test('preserves assistant text when dropping orphan tool uses', () => { + const messages = [ + { + type: 'assistant', + uuid: 'a1', + message: { + role: 'assistant', + content: [ + { type: 'text', text: 'I will read the file.' }, + { type: 'tool_use', id: 'missing', name: 'Read' }, + ], + }, + }, + ] as unknown as Message[] + + const filtered = filterIncompleteToolCalls(messages) + expect(filtered).toHaveLength(1) + const first = filtered[0]! + const content = first.message!.content + expect( + Array.isArray(content) ? content.map(block => block.type) : [], + ).toEqual(['text']) + }) + + test('keeps completed parallel tool calls when dropping an orphan', () => { + const messages = [ + { + type: 'assistant', + uuid: 'a1', + message: { + role: 'assistant', + content: [ + { type: 'tool_use', id: 'done', name: 'Read' }, + { type: 'tool_use', id: 'missing', name: 'Grep' }, + ], + }, + }, + { + type: 'user', + uuid: 'u1', + message: { + role: 'user', + content: [{ type: 'tool_result', tool_use_id: 'done', content: 'ok' }], + }, + }, + ] as unknown as Message[] + + const filtered = filterIncompleteToolCalls(messages) + expect(filtered.map(message => String(message.uuid))).toEqual(['a1', 'u1']) + const first = filtered[0]! + const content = first.message!.content + expect( + Array.isArray(content) + ? content.map(block => + block.type === 'tool_use' ? block.id : block.type, + ) + : [], + ).toEqual(['done']) + }) + + test('keeps assistant tool uses that have matching results', () => { + const messages = [ + { + type: 'assistant', + uuid: 'a1', + message: { + role: 'assistant', + content: [{ type: 'tool_use', id: 'done', name: 'Read' }], + }, + }, + { + type: 'user', + uuid: 'u1', + message: { + role: 'user', + content: [{ type: 'tool_result', tool_use_id: 'done', content: 'ok' }], + }, + }, + ] as unknown as Message[] + + expect( + filterIncompleteToolCalls(messages).map(message => String(message.uuid)), + ).toEqual(['a1', 'u1']) + }) + + test('drops orphan tool results when their tool use was removed', () => { + const messages = [ + { + type: 'user', + uuid: 'u1', + message: { + role: 'user', + content: [ + { type: 'tool_result', tool_use_id: 'missing', content: 'late' }, + ], + }, + }, + ] as unknown as Message[] + + expect(filterIncompleteToolCalls(messages)).toEqual([]) + }) + + test('keeps user text while dropping orphan tool results', () => { + const messages = [ + { + type: 'assistant', + uuid: 'a1', + message: { role: 'assistant', content: 'done' }, + }, + { + type: 'user', + uuid: 'u1', + message: { + role: 'user', + content: [ + { type: 'text', text: 'keep this' }, + { type: 'tool_result', tool_use_id: 'missing', content: 'late' }, + ], + }, + }, + ] as unknown as Message[] + + const filtered = filterIncompleteToolCalls(messages) + expect(filtered.map(message => String(message.uuid))).toEqual(['a1', 'u1']) + const content = filtered[1]!.message!.content + expect(Array.isArray(content) ? content : []).toEqual([ + { type: 'text', text: 'keep this' }, + ]) + }) + + test('drops malformed tool blocks without ids', () => { + const messages = [ + { + type: 'assistant', + uuid: 'a1', + message: { + role: 'assistant', + content: [{ type: 'tool_use', name: 'Read' }], + }, + }, + { + type: 'user', + uuid: 'u1', + message: { + role: 'user', + content: [{ type: 'tool_result', content: 'late' }], + }, + }, + ] as unknown as Message[] + + expect(filterIncompleteToolCalls(messages)).toEqual([]) + }) +}) diff --git a/packages/builtin-tools/src/tools/AgentTool/filterIncompleteToolCalls.ts b/packages/builtin-tools/src/tools/AgentTool/filterIncompleteToolCalls.ts new file mode 100644 index 0000000000..7e30754eea --- /dev/null +++ b/packages/builtin-tools/src/tools/AgentTool/filterIncompleteToolCalls.ts @@ -0,0 +1,110 @@ +import type { + AssistantMessage, + Message, + UserMessage, +} from 'src/types/message.js' + +/** + * Removes invalid or orphaned tool_use/tool_result blocks while preserving + * completed tool-call pairs. This is intentionally block-level, not + * message-level, so completed parallel tool calls stay paired with results. + */ +export function filterIncompleteToolCalls(messages: Message[]): Message[] { + const toolUseIdsWithResults = new Set() + + for (const message of messages) { + if (message?.type === 'user') { + const userMessage = message as UserMessage + const content = userMessage.message.content + if (Array.isArray(content)) { + for (const block of content) { + if (block.type === 'tool_result' && block.tool_use_id) { + toolUseIdsWithResults.add(block.tool_use_id) + } + } + } + } + } + + const retainedToolUseIds = new Set() + const withoutOrphanToolUses: Message[] = [] + + for (const message of messages) { + if (message?.type === 'assistant') { + const assistantMessage = message as AssistantMessage + const content = assistantMessage.message.content + if (Array.isArray(content)) { + let changed = false + const filteredContent = content.filter(block => { + if (block.type !== 'tool_use') return true + if (!block.id) { + changed = true + return false + } + if (toolUseIdsWithResults.has(block.id)) { + retainedToolUseIds.add(block.id) + return true + } + changed = true + return false + }) + + if (!changed) { + withoutOrphanToolUses.push(message) + continue + } + if (filteredContent.length > 0) { + withoutOrphanToolUses.push({ + ...assistantMessage, + message: { + ...assistantMessage.message, + content: filteredContent, + }, + }) + } + continue + } + } + withoutOrphanToolUses.push(message) + } + + const filteredMessages: Message[] = [] + for (const message of withoutOrphanToolUses) { + if (message?.type !== 'user') { + filteredMessages.push(message) + continue + } + const userMessage = message as UserMessage + const content = userMessage.message.content + if (!Array.isArray(content)) { + filteredMessages.push(message) + continue + } + let changed = false + const filteredContent = content.filter(block => { + if (block.type !== 'tool_result') return true + if (!block.tool_use_id) { + changed = true + return false + } + if (retainedToolUseIds.has(block.tool_use_id)) return true + changed = true + return false + }) + if (!changed) { + filteredMessages.push(message) + continue + } + if (filteredContent.length > 0) { + filteredMessages.push({ + ...userMessage, + message: { + ...userMessage.message, + content: filteredContent, + }, + }) + } + } + + return filteredMessages +} diff --git a/packages/builtin-tools/src/tools/AgentTool/runAgent.ts b/packages/builtin-tools/src/tools/AgentTool/runAgent.ts index baeed9022d..de55b53f8f 100644 --- a/packages/builtin-tools/src/tools/AgentTool/runAgent.ts +++ b/packages/builtin-tools/src/tools/AgentTool/runAgent.ts @@ -86,8 +86,11 @@ import { import type { ContentReplacementState } from 'src/utils/toolResultStorage.js' import { createAgentId } from 'src/utils/uuid.js' import { resolveAgentTools } from './agentToolUtils.js' +import { filterIncompleteToolCalls } from './filterIncompleteToolCalls.js' import { type AgentDefinition, isBuiltInAgent } from './loadAgentsDir.js' +export { filterIncompleteToolCalls } from './filterIncompleteToolCalls.js' + /** * Initialize agent-specific MCP servers * Agents can define their own MCP servers in their frontmatter that are additive @@ -886,50 +889,6 @@ export async function* runAgent({ } } -/** - * Filters out assistant messages with incomplete tool calls (tool uses without results). - * This prevents API errors when sending messages with orphaned tool calls. - */ -export function filterIncompleteToolCalls(messages: Message[]): Message[] { - // Build a set of tool use IDs that have results - const toolUseIdsWithResults = new Set() - - for (const message of messages) { - if (message?.type === 'user') { - const userMessage = message as UserMessage - const content = userMessage.message.content - if (Array.isArray(content)) { - for (const block of content) { - if (block.type === 'tool_result' && block.tool_use_id) { - toolUseIdsWithResults.add(block.tool_use_id) - } - } - } - } - } - - // Filter out assistant messages that contain tool calls without results - return messages.filter(message => { - if (message?.type === 'assistant') { - const assistantMessage = message as AssistantMessage - const content = assistantMessage.message.content - if (Array.isArray(content)) { - // Check if this assistant message has any tool uses without results - const hasIncompleteToolCall = content.some( - block => - block.type === 'tool_use' && - block.id && - !toolUseIdsWithResults.has(block.id), - ) - // Exclude messages with incomplete tool calls - return !hasIncompleteToolCall - } - } - // Keep all non-assistant messages and assistant messages without tool calls - return true - }) -} - async function getAgentSystemPrompt( agentDefinition: AgentDefinition, toolUseContext: Pick, diff --git a/packages/builtin-tools/src/tools/SendMessageTool/SendMessageTool.ts b/packages/builtin-tools/src/tools/SendMessageTool/SendMessageTool.ts index e4868bc531..cab0a03c5c 100644 --- a/packages/builtin-tools/src/tools/SendMessageTool/SendMessageTool.ts +++ b/packages/builtin-tools/src/tools/SendMessageTool/SendMessageTool.ts @@ -131,15 +131,16 @@ export type SendMessageToolOutput = | ResponseOutput const UDS_INLINE_TOKEN_MARKER = '#token=' -const UDS_INLINE_TOKEN_REJECTED_KEY = '__udsInlineTokenRejected' function stripInlineUdsToken(target: string): string { - const markerIndex = target.lastIndexOf(UDS_INLINE_TOKEN_MARKER) + const markerIndex = target.indexOf(UDS_INLINE_TOKEN_MARKER) return markerIndex === -1 ? target : target.slice(0, markerIndex) } function hasInlineUdsToken(to: string): boolean { const addr = parseAddress(to) + // Empty-token markers are still inline-token attempts. Observable input + // redaction preserves "#token=" so cloned inputs remain rejected. return ( addr.scheme === 'uds' && addr.target.includes(UDS_INLINE_TOKEN_MARKER) ) @@ -151,20 +152,17 @@ function recipientForDisplay(to: string): string { return `uds:${stripInlineUdsToken(addr.target)}` } -function markAndRedactInlineUdsToken( - input: { to: string } & Record, -): void { - if (!hasInlineUdsToken(input.to)) return - input.to = recipientForDisplay(input.to) - input[UDS_INLINE_TOKEN_REJECTED_KEY] = true +function redactInlineUdsTokenForRejection(to: string): string { + const addr = parseAddress(to) + if (addr.scheme !== 'uds') return to + const markerIndex = addr.target.indexOf(UDS_INLINE_TOKEN_MARKER) + if (markerIndex === -1) return to + return `uds:${addr.target.slice(0, markerIndex)}${UDS_INLINE_TOKEN_MARKER}` } -function wasInlineUdsTokenRejected(input: unknown): boolean { - return ( - typeof input === 'object' && - input !== null && - (input as Record)[UDS_INLINE_TOKEN_REJECTED_KEY] === true - ) +function redactObservableInlineUdsToken(input: { to: string }): void { + if (!hasInlineUdsToken(input.to)) return + input.to = redactInlineUdsTokenForRejection(input.to) } function findTeammateColor( @@ -580,9 +578,7 @@ export const SendMessageTool: Tool = backfillObservableInput(input) { if (typeof input.to !== 'string') return - markAndRedactInlineUdsToken( - input as { to: string } & Record, - ) + redactObservableInlineUdsToken(input as { to: string }) if ('type' in input) return if (input.to === '*') { @@ -674,7 +670,7 @@ export const SendMessageTool: Tool = } if ( addr.scheme === 'uds' && - (hasInlineUdsToken(input.to) || wasInlineUdsTokenRejected(input)) + hasInlineUdsToken(input.to) ) { return { result: false, @@ -808,10 +804,7 @@ export const SendMessageTool: Tool = async call(input, context, canUseTool, assistantMessage) { if (typeof input.message === 'string') { const addr = parseAddress(input.to) - if ( - addr.scheme === 'uds' && - (hasInlineUdsToken(input.to) || wasInlineUdsTokenRejected(input)) - ) { + if (addr.scheme === 'uds' && hasInlineUdsToken(input.to)) { return { data: { success: false, @@ -857,15 +850,6 @@ export const SendMessageTool: Tool = } if (addr.scheme === 'uds') { const recipient = recipientForDisplay(input.to) - if (hasInlineUdsToken(input.to) || wasInlineUdsTokenRejected(input)) { - return { - data: { - success: false, - message: - 'uds addresses must not include inline auth tokens; use the ListPeers address', - }, - } - } /* eslint-disable @typescript-eslint/no-require-imports */ const { sendToUdsSocket } = require('src/utils/udsClient.js') as typeof import('src/utils/udsClient.js') diff --git a/packages/builtin-tools/src/tools/SendMessageTool/__tests__/udsRecipientSanitization.test.ts b/packages/builtin-tools/src/tools/SendMessageTool/__tests__/udsRecipientSanitization.test.ts index a0ab2af0dc..e0ce1a8233 100644 --- a/packages/builtin-tools/src/tools/SendMessageTool/__tests__/udsRecipientSanitization.test.ts +++ b/packages/builtin-tools/src/tools/SendMessageTool/__tests__/udsRecipientSanitization.test.ts @@ -1,8 +1,8 @@ import { describe, expect, test } from 'bun:test' +import { SendMessageTool } from '../SendMessageTool.js' describe('SendMessageTool UDS recipient handling', () => { test('redacts inline UDS tokens before classifier and observable paths', async () => { - const { SendMessageTool } = await import('../SendMessageTool.js') const tokenAddress = 'uds:/tmp/peer.sock#token=secret-token' const observableInput = { @@ -12,6 +12,7 @@ describe('SendMessageTool UDS recipient handling', () => { SendMessageTool.backfillObservableInput!(observableInput) expect(observableInput.recipient).toBe('uds:/tmp/peer.sock') + expect(observableInput.to).toBe('uds:/tmp/peer.sock#token=') expect(JSON.stringify(observableInput)).not.toContain('secret-token') expect( SendMessageTool.toAutoClassifierInput({ @@ -22,7 +23,6 @@ describe('SendMessageTool UDS recipient handling', () => { }) test('keeps redacted UDS token rejection through observable backfill', async () => { - const { SendMessageTool } = await import('../SendMessageTool.js') const observableInput = { to: 'uds:/tmp/peer.sock#token=secret-token', message: { @@ -35,7 +35,7 @@ describe('SendMessageTool UDS recipient handling', () => { SendMessageTool.backfillObservableInput!(observableInput) - expect(observableInput.to).toBe('uds:/tmp/peer.sock') + expect(observableInput.to).toBe('uds:/tmp/peer.sock#token=') expect(observableInput.recipient).toBe('uds:/tmp/peer.sock') expect(observableInput.type).toBe('plan_approval_response') expect(observableInput.request_id).toBe('req-1') @@ -55,8 +55,37 @@ describe('SendMessageTool UDS recipient handling', () => { expect(result.message).toContain('inline auth tokens') }) + test('keeps inline-token rejection when observable input is cloned', async () => { + const observableInput = { + to: 'uds:/tmp/peer.sock#token=secret-token', + message: 'hello', + } as Record + + SendMessageTool.backfillObservableInput!(observableInput) + const clonedInput = { + to: observableInput.to, + message: observableInput.message, + summary: 'hello peer', + } + + const validation = await SendMessageTool.validateInput!( + clonedInput as never, + {} as never, + ) + const result = await SendMessageTool.call( + clonedInput as never, + {} as never, + undefined as never, + undefined as never, + ) + + expect(validation.result).toBe(false) + expect(result.data.success).toBe(false) + expect(JSON.stringify(clonedInput)).not.toContain('secret-token') + expect(JSON.stringify(result)).not.toContain('secret-token') + }) + test('redacts UDS tokens in structured classifier text', async () => { - const { SendMessageTool } = await import('../SendMessageTool.js') const to = 'uds:/tmp/peer.sock#token=secret-token' expect( @@ -75,10 +104,50 @@ describe('SendMessageTool UDS recipient handling', () => { }, }), ).toBe('plan_approval approve to uds:/tmp/peer.sock') + expect( + SendMessageTool.toAutoClassifierInput({ + to, + message: { + type: 'plan_approval_response', + request_id: 'req-2', + approve: false, + }, + }), + ).toBe('plan_approval reject to uds:/tmp/peer.sock') + expect( + SendMessageTool.toAutoClassifierInput({ + to, + message: { + type: 'shutdown_response', + request_id: 'shutdown-1', + approve: false, + }, + }), + ).toBe('shutdown_response reject shutdown-1') + }) + + test('redacts from the first inline UDS token marker', async () => { + const tokenAddress = 'uds:/tmp/peer.sock#token=first#token=second' + + const observableInput = { + to: tokenAddress, + message: 'hello', + } as Record + SendMessageTool.backfillObservableInput!(observableInput) + + expect(observableInput.to).toBe('uds:/tmp/peer.sock#token=') + expect(observableInput.recipient).toBe('uds:/tmp/peer.sock') + expect(JSON.stringify(observableInput)).not.toContain('first') + expect(JSON.stringify(observableInput)).not.toContain('second') + expect( + SendMessageTool.toAutoClassifierInput({ + to: tokenAddress, + message: 'hello', + }), + ).toBe('to uds:/tmp/peer.sock: hello') }) test('rejects inline UDS tokens during validation', async () => { - const { SendMessageTool } = await import('../SendMessageTool.js') const result = await SendMessageTool.validateInput!( { to: 'uds:/tmp/peer.sock#token=secret-token', @@ -96,7 +165,6 @@ describe('SendMessageTool UDS recipient handling', () => { }) test('rejects inline UDS tokens during execution without leaking them', async () => { - const { SendMessageTool } = await import('../SendMessageTool.js') const result = await SendMessageTool.call( { to: 'uds:/tmp/peer.sock#token=secret-token', diff --git a/src/commands/poor/__tests__/poorMode.test.ts b/src/commands/poor/__tests__/poorMode.test.ts index c2a80f3cf4..539c804e18 100644 --- a/src/commands/poor/__tests__/poorMode.test.ts +++ b/src/commands/poor/__tests__/poorMode.test.ts @@ -5,7 +5,8 @@ * After the fix, it reads from / writes to settings.json via * getInitialSettings() and updateSettingsForSource(). */ -import { describe, expect, test, beforeEach, mock } from 'bun:test' +import { afterAll, describe, expect, test, beforeEach, mock } from 'bun:test' +import * as settingsModule from '../../../utils/settings/settings.js' // ── Mocks must be declared before the module under test is imported ────────── @@ -13,24 +14,48 @@ let mockSettings: Record = {} let lastUpdate: { source: string; patch: Record } | null = null mock.module('src/utils/settings/settings.js', () => ({ + loadManagedFileSettings: () => ({ settings: null, errors: [] }), + getManagedFileSettingsPresence: () => ({ + hasBase: false, + hasDropIns: false, + }), + parseSettingsFile: () => ({ settings: null, errors: [] }), + getSettingsRootPathForSource: () => '', + getSettingsFilePathForSource: () => undefined, + getRelativeSettingsFilePathForSource: () => '', getInitialSettings: () => mockSettings, + getSettingsForSource: () => mockSettings, + getPolicySettingsOrigin: () => null, + getSettingsWithErrors: () => ({ settings: mockSettings, errors: [] }), + getSettingsWithSources: () => ({ effective: mockSettings, sources: [] }), + getSettings_DEPRECATED: () => mockSettings, + settingsMergeCustomizer: () => undefined, + getManagedSettingsKeysForLogging: () => [], + // Keep unrelated exports aligned with the real settings module so this + // full-surface mock cannot change later test files if Bun keeps it alive. + hasAutoModeOptIn: () => true, + hasSkipDangerousModePermissionPrompt: () => false, + getAutoModeConfig: () => undefined, + getUseAutoModeDuringPlan: () => true, + rawSettingsContainsKey: (key: string) => key in mockSettings, updateSettingsForSource: (source: string, patch: Record) => { lastUpdate = { source, patch } mockSettings = { ...mockSettings, ...patch } }, })) -// Import AFTER mocks are registered -const { isPoorModeActive, setPoorMode } = await import('../poorMode.js') - -// ── Helpers ────────────────────────────────────────────────────────────────── +afterAll(() => { + mock.restore() + mock.module('src/utils/settings/settings.js', () => settingsModule) +}) -/** Reset module-level singleton between tests by re-importing a fresh copy. */ -async function freshModule() { - // Bun caches modules; we manipulate the exported functions directly since - // the singleton `poorModeActive` is reset to null only on first import. - // Instead we test the observable behaviour through set/get pairs. -} +// Import AFTER mocks are registered. The query suffix gives this file its own +// module instance so cross-file poorMode.js mocks cannot replace the subject +// under test during Bun's shared coverage run. +const poorModeModulePath = '../poorMode.js?poorModeTest' +const { isPoorModeActive, setPoorMode } = (await import( + poorModeModulePath +)) as typeof import('../poorMode.js') // ── Tests ──────────────────────────────────────────────────────────────────── diff --git a/src/services/AgentSummary/__tests__/agentSummary.test.ts b/src/services/AgentSummary/__tests__/agentSummary.test.ts index 0ab0700800..368671f039 100644 --- a/src/services/AgentSummary/__tests__/agentSummary.test.ts +++ b/src/services/AgentSummary/__tests__/agentSummary.test.ts @@ -1,16 +1,11 @@ -import { - afterAll, - afterEach, - beforeEach, - describe, - expect, - mock, - test, -} from 'bun:test' -import { debugMock } from '../../../../tests/mocks/debug' -import { logMock } from '../../../../tests/mocks/log' +import { beforeEach, describe, expect, test } from 'bun:test' import { asAgentId } from '../../../types/ids.js' -import type { CacheSafeParams } from '../../../utils/forkedAgent.js' +import type { Message } from '../../../types/message.js' +import type { + CacheSafeParams, + ForkedAgentResult, +} from '../../../utils/forkedAgent.js' +import { startAgentSummarization } from '../agentSummary.js' const transcriptMessages = [ { type: 'user', message: { content: 'start' }, uuid: 'u1' }, @@ -20,114 +15,138 @@ const transcriptMessages = [ uuid: 'a1', }, { type: 'user', message: { content: 'continue' }, uuid: 'u2' }, -] +] as unknown as Message[] -let poorModeActive = false -let forkCalls = 0 -let updateCalls: Array<{ taskId: string; summary: string }> = [] -let transcript = { messages: transcriptMessages } -const sessionStorageSnapshot = { - ...(require('../../../utils/sessionStorage.ts') as Record), +type ForkCall = { + cacheSafeParams: CacheSafeParams } -mock.module('src/commands/poor/poorMode.js', () => ({ - isPoorModeActive: () => poorModeActive, -})) - -mock.module('src/tasks/LocalAgentTask/LocalAgentTask.js', () => ({ - updateAgentSummary: (taskId: string, summary: string) => { - updateCalls.push({ taskId, summary }) - }, -})) - -mock.module( - '@claude-code-best/builtin-tools/tools/AgentTool/runAgent.js', - () => ({ - filterIncompleteToolCalls: (messages: T) => messages, - }), -) - -mock.module('src/utils/debug.js', debugMock) -mock.module('src/utils/log.js', logMock) - -mock.module('src/utils/forkedAgent.js', () => ({ - runForkedAgent: async () => { - forkCalls += 1 - return { - messages: [ - { - type: 'assistant', - message: { - content: [{ type: 'text', text: 'Reading udsClient.ts' }], - }, - }, - ], - } - }, -})) - -mock.module('src/utils/sessionStorage.js', () => ({ - ...sessionStorageSnapshot, - getAgentTranscript: async () => transcript, -})) - -afterAll(() => { - mock.module('src/utils/sessionStorage.js', () => - require('../../../utils/sessionStorage.ts'), - ) -}) - describe('startAgentSummarization', () => { - const realSetTimeout = globalThis.setTimeout - const realClearTimeout = globalThis.clearTimeout - let scheduled: - | ((...args: Parameters void)>) => void) - | undefined + let scheduled: (() => void | Promise) | undefined + let handle: { stop: () => void } | undefined + let forkCalls: ForkCall[] + let updateCalls: Array<{ taskId: string; summary: string }> + let transcriptMessagesForTest: Message[] beforeEach(() => { - poorModeActive = false - forkCalls = 0 + forkCalls = [] updateCalls = [] - transcript = { messages: transcriptMessages } scheduled = undefined - globalThis.setTimeout = ((callback: TimerHandler) => { - scheduled = callback as (...args: unknown[]) => void - return 1 as unknown as ReturnType - }) as unknown as typeof setTimeout - globalThis.clearTimeout = (() => undefined) as typeof clearTimeout - }) - - afterEach(() => { - globalThis.setTimeout = realSetTimeout - globalThis.clearTimeout = realClearTimeout + handle = undefined + transcriptMessagesForTest = transcriptMessages }) test('summarizes bounded transcript once and skips unchanged fingerprints', async () => { - const { startAgentSummarization } = await import('../agentSummary.js') - - const handle = startAgentSummarization( + handle = startAgentSummarization( 'task-1', asAgentId('a0000000000000000'), { - forkContextMessages: [{ type: 'user', message: { content: 'old' } }], + forkContextMessages: [ + { type: 'user', message: { content: 'stale' }, uuid: 'old' }, + ], model: 'claude-test', } as unknown as CacheSafeParams, () => undefined, + { + clearTimeout: () => undefined, + getAgentTranscript: async () => ({ + messages: transcriptMessagesForTest, + contentReplacements: [], + }), + isPoorModeActive: () => false, + logError: () => undefined, + logForDebugging: () => undefined, + runForkedAgent: async (args: ForkCall) => { + forkCalls.push(args) + return { + messages: [ + { + type: 'assistant', + message: { + content: [{ type: 'text', text: 'Reading udsClient.ts' }], + }, + }, + ], + } as unknown as ForkedAgentResult + }, + setTimeout: ((callback: TimerHandler) => { + if (typeof callback !== 'function') { + throw new Error('Expected timer callback') + } + scheduled = callback as () => void | Promise + return 1 as unknown as ReturnType + }) as unknown as typeof setTimeout, + updateAgentSummary: (taskId: string, summary: string) => { + updateCalls.push({ taskId, summary }) + }, + }, ) expect(typeof scheduled).toBe('function') await scheduled!() - expect(forkCalls).toBe(1) + expect(forkCalls).toHaveLength(1) expect(updateCalls).toEqual([ { taskId: 'task-1', summary: 'Reading udsClient.ts' }, ]) + const forkContext = forkCalls[0].cacheSafeParams.forkContextMessages ?? [] + expect(forkContext.map(message => String(message.uuid))).toEqual([ + 'u1', + 'a1', + 'u2', + ]) + expect(forkContext.some(message => String(message.uuid) === 'old')).toBe( + false, + ) + await scheduled!() - expect(forkCalls).toBe(1) + expect(forkCalls).toHaveLength(1) expect(updateCalls).toHaveLength(1) + }) + + test('skips summarization when bounded context is too small', async () => { + transcriptMessagesForTest = transcriptMessages.slice(0, 2) + + handle = startAgentSummarization( + 'task-1', + asAgentId('a0000000000000000'), + { + forkContextMessages: transcriptMessages, + model: 'claude-test', + } as unknown as CacheSafeParams, + () => undefined, + { + clearTimeout: () => undefined, + getAgentTranscript: async () => ({ + messages: transcriptMessagesForTest, + contentReplacements: [], + }), + isPoorModeActive: () => false, + logError: () => undefined, + logForDebugging: () => undefined, + runForkedAgent: async (args: ForkCall) => { + forkCalls.push(args) + return { messages: [] } as unknown as ForkedAgentResult + }, + setTimeout: ((callback: TimerHandler) => { + if (typeof callback !== 'function') { + throw new Error('Expected timer callback') + } + scheduled = callback as () => void | Promise + return 1 as unknown as ReturnType + }) as unknown as typeof setTimeout, + updateAgentSummary: (taskId: string, summary: string) => { + updateCalls.push({ taskId, summary }) + }, + }, + ) + + expect(typeof scheduled).toBe('function') + await scheduled!() - handle.stop() + expect(forkCalls).toEqual([]) + expect(updateCalls).toEqual([]) }) }) diff --git a/src/services/AgentSummary/__tests__/summaryContext.test.ts b/src/services/AgentSummary/__tests__/summaryContext.test.ts index 1c701d14b4..fe0eb30572 100644 --- a/src/services/AgentSummary/__tests__/summaryContext.test.ts +++ b/src/services/AgentSummary/__tests__/summaryContext.test.ts @@ -1,6 +1,8 @@ import { describe, expect, test } from 'bun:test' import type { Message } from '../../../types/message.js' import { + buildSummaryContext, + estimateMessageChars, getSummaryContextFingerprint, MAX_SUMMARY_CONTEXT_CHARS, selectSummaryContextMessages, @@ -75,6 +77,21 @@ describe('selectSummaryContextMessages', () => { expect(selected).toEqual([]) }) + test('stops at an older oversized message after keeping the recent suffix', () => { + const messages = [ + makeMessage('user', 'u1', 'x'.repeat(5_000)), + makeMessage('user', 'u2', 'small prompt'), + makeMessage('assistant', 'a2', 'small answer'), + ] + + const selected = selectSummaryContextMessages(messages, { + maxMessages: 10, + maxChars: 1_000, + }) + + expect(selected.map(message => String(message.uuid))).toEqual(['u2', 'a2']) + }) + test('drops leading orphan tool results after bounding', () => { const messages = [ makeMessage('assistant', 'a0', 'older assistant'), @@ -102,6 +119,28 @@ describe('selectSummaryContextMessages', () => { }) describe('getSummaryContextFingerprint', () => { + test('estimates circular messages as unbounded', () => { + const circular = makeMessage('assistant', 'a1', 'cycle') as Message & { + self?: unknown + } + circular.self = circular + + expect(estimateMessageChars(circular)).toBe(Number.POSITIVE_INFINITY) + }) + + test('ignores non-json primitive fields in size estimates', () => { + const message = makeMessage('assistant', 'a1', 'metadata') as Message & { + skipUndefined?: undefined + skipFunction?: () => void + skipSymbol?: symbol + } + message.skipUndefined = undefined + message.skipFunction = () => undefined + message.skipSymbol = Symbol('ignored') + + expect(estimateMessageChars(message)).toBeGreaterThan(0) + }) + test('returns null for an empty transcript', () => { expect(getSummaryContextFingerprint([])).toBeNull() }) @@ -146,4 +185,77 @@ describe('getSummaryContextFingerprint', () => { expect(first).not.toBe(second) }) + + test('fingerprints circular message references without recursing forever', () => { + const circular = makeMessage('assistant', 'a1', 'cycle') as Message & { + self?: unknown + } + circular.self = circular + + expect(getSummaryContextFingerprint([circular])).toContain(':a1:') + }) +}) + +describe('buildSummaryContext', () => { + test('returns bounded messages and fingerprint for summarizable context', () => { + const messages = [ + { type: 'user', uuid: 'u1', message: { content: 'start' } }, + { + type: 'assistant', + uuid: 'a1', + message: { content: [{ type: 'text', text: 'working' }] }, + }, + { type: 'user', uuid: 'u2', message: { content: 'continue' } }, + ] as unknown as Message[] + + const result = buildSummaryContext(messages, null) + + expect(result.skipReason).toBeUndefined() + expect(result.messages.map(message => String(message.uuid))).toEqual([ + 'u1', + 'a1', + 'u2', + ]) + expect(result.fingerprint).toContain('3:u2:') + }) + + test('reports unchanged contexts by fingerprint', () => { + const messages = [ + { type: 'user', uuid: 'u1', message: { content: 'start' } }, + { + type: 'assistant', + uuid: 'a1', + message: { content: [{ type: 'text', text: 'working' }] }, + }, + { type: 'user', uuid: 'u2', message: { content: 'continue' } }, + ] as unknown as Message[] + const first = buildSummaryContext(messages, null) + + const second = buildSummaryContext(messages, first.fingerprint) + + expect(second.skipReason).toBe('unchanged') + expect(second.fingerprint).toBe(first.fingerprint) + }) + + test('filters incomplete tool calls before deciding context is too small', () => { + const messages = [ + { type: 'user', uuid: 'u1', message: { content: 'start' } }, + { + type: 'assistant', + uuid: 'a1', + message: { + content: [{ type: 'tool_use', id: 'missing', name: 'Read' }], + }, + }, + { type: 'user', uuid: 'u2', message: { content: 'continue' } }, + ] as unknown as Message[] + + const result = buildSummaryContext(messages, null) + + expect(result.skipReason).toBe('too_small') + expect(result.messages.map(message => String(message.uuid))).toEqual([ + 'u1', + 'u2', + ]) + }) }) diff --git a/src/services/AgentSummary/__tests__/summaryPrompt.test.ts b/src/services/AgentSummary/__tests__/summaryPrompt.test.ts new file mode 100644 index 0000000000..9e8f03cac6 --- /dev/null +++ b/src/services/AgentSummary/__tests__/summaryPrompt.test.ts @@ -0,0 +1,34 @@ +import { describe, expect, test } from 'bun:test' +import { + buildSummaryPrompt, + createSummaryPromptMessage, +} from '../summaryPrompt.js' + +describe('buildSummaryPrompt', () => { + test('builds the first summary prompt without previous-summary pressure', () => { + const prompt = buildSummaryPrompt(null) + + expect(prompt).toContain('Describe your most recent action') + expect(prompt).toContain('Good: "Reading runAgent.ts"') + expect(prompt).not.toContain('Previous:') + }) + + test('asks for a new summary when a previous one exists', () => { + const prompt = buildSummaryPrompt('Reading udsMessaging.ts') + + expect(prompt).toContain('Previous: "Reading udsMessaging.ts"') + expect(prompt).toContain('say something NEW') + }) +}) + +describe('createSummaryPromptMessage', () => { + test('creates the minimal user message shape used by forked summaries', () => { + const message = createSummaryPromptMessage('Summarize progress') + + expect(message.type).toBe('user') + expect(message.message.role).toBe('user') + expect(message.message.content).toBe('Summarize progress') + expect(message.uuid).toBeString() + expect(message.timestamp).toBeString() + }) +}) diff --git a/src/services/AgentSummary/agentSummary.ts b/src/services/AgentSummary/agentSummary.ts index 2232e839d7..d212a5c72c 100644 --- a/src/services/AgentSummary/agentSummary.ts +++ b/src/services/AgentSummary/agentSummary.ts @@ -13,7 +13,6 @@ import type { TaskContext } from '../../Task.js' import { isPoorModeActive } from '../../commands/poor/poorMode.js' import { updateAgentSummary } from '../../tasks/LocalAgentTask/LocalAgentTask.js' -import { filterIncompleteToolCalls } from '@claude-code-best/builtin-tools/tools/AgentTool/runAgent.js' import type { AgentId } from '../../types/ids.js' import { logForDebugging } from '../../utils/debug.js' import { @@ -21,38 +20,32 @@ import { runForkedAgent, } from '../../utils/forkedAgent.js' import { logError } from '../../utils/log.js' -import { createUserMessage } from '../../utils/messages.js' import { getAgentTranscript } from '../../utils/sessionStorage.js' +import { buildSummaryContext } from './summaryContext.js' import { - getSummaryContextFingerprint, - selectSummaryContextMessages, -} from './summaryContext.js' + buildSummaryPrompt, + createSummaryPromptMessage, +} from './summaryPrompt.js' const SUMMARY_INTERVAL_MS = 30_000 -function buildSummaryPrompt(previousSummary: string | null): string { - const prevLine = previousSummary - ? `\nPrevious: "${previousSummary}" — say something NEW.\n` - : '' - - return `Describe your most recent action in 3-5 words using present tense (-ing). Name the file or function, not the branch. Do not use tools. -${prevLine} -Good: "Reading runAgent.ts" -Good: "Fixing null check in validate.ts" -Good: "Running auth module tests" -Good: "Adding retry logic to fetchUser" - -Bad (past tense): "Analyzed the branch diff" -Bad (too vague): "Investigating the issue" -Bad (too long): "Reviewing full branch diff and AgentTool.tsx integration" -Bad (branch name): "Analyzed adam/background-summary branch diff"` -} +export type AgentSummaryDependencies = Partial<{ + clearTimeout: typeof clearTimeout + getAgentTranscript: typeof getAgentTranscript + isPoorModeActive: typeof isPoorModeActive + logError: typeof logError + logForDebugging: typeof logForDebugging + runForkedAgent: typeof runForkedAgent + setTimeout: typeof setTimeout + updateAgentSummary: typeof updateAgentSummary +}> export function startAgentSummarization( taskId: string, agentId: AgentId, cacheSafeParams: CacheSafeParams, setAppState: TaskContext['setAppState'], + dependencies: AgentSummaryDependencies = {}, ): { stop: () => void } { // Drop forkContextMessages from the closure — runSummary rebuilds it each // tick from getAgentTranscript(). Without this, the original fork messages @@ -63,46 +56,53 @@ export function startAgentSummarization( let stopped = false let previousSummary: string | null = null let lastHandledTranscriptFingerprint: string | null = null + const clearTimeoutImpl = dependencies.clearTimeout ?? clearTimeout + const getAgentTranscriptImpl = + dependencies.getAgentTranscript ?? getAgentTranscript + const isPoorModeActiveImpl = + dependencies.isPoorModeActive ?? isPoorModeActive + const logErrorImpl = dependencies.logError ?? logError + const logForDebuggingImpl = + dependencies.logForDebugging ?? logForDebugging + const runForkedAgentImpl = dependencies.runForkedAgent ?? runForkedAgent + const setTimeoutImpl = dependencies.setTimeout ?? setTimeout + const updateAgentSummaryImpl = + dependencies.updateAgentSummary ?? updateAgentSummary async function runSummary(): Promise { if (stopped) return - if (isPoorModeActive()) { - logForDebugging('[AgentSummary] Skipping summary — poor mode active') + if (isPoorModeActiveImpl()) { + logForDebuggingImpl('[AgentSummary] Skipping summary — poor mode active') scheduleNext() return } - logForDebugging(`[AgentSummary] Timer fired for agent ${agentId}`) + logForDebuggingImpl(`[AgentSummary] Timer fired for agent ${agentId}`) try { // Read current messages from transcript - const transcript = await getAgentTranscript(agentId) + const transcript = await getAgentTranscriptImpl(agentId) if (!transcript || transcript.messages.length < 3) { // Not enough context yet — finally block will schedule next attempt - logForDebugging( + logForDebuggingImpl( `[AgentSummary] Skipping summary for ${taskId}: not enough messages (${transcript?.messages.length ?? 0})`, ) return } - // Filter to clean message state - const cleanMessages = filterIncompleteToolCalls(transcript.messages) - const summaryContext = filterIncompleteToolCalls( - selectSummaryContextMessages(cleanMessages), + const summaryContext = buildSummaryContext( + transcript.messages, + lastHandledTranscriptFingerprint, ) - const transcriptFingerprint = getSummaryContextFingerprint(summaryContext) - if ( - transcriptFingerprint && - transcriptFingerprint === lastHandledTranscriptFingerprint - ) { - logForDebugging( + if (summaryContext.skipReason === 'unchanged') { + logForDebuggingImpl( `[AgentSummary] Skipping summary for ${taskId}: transcript unchanged`, ) return } - if (summaryContext.length < 3) { - logForDebugging( + if (summaryContext.skipReason === 'too_small') { + logForDebuggingImpl( `[AgentSummary] Skipping summary for ${taskId}: no bounded context available`, ) return @@ -111,11 +111,11 @@ export function startAgentSummarization( // Build fork params with current messages const forkParams: CacheSafeParams = { ...baseParams, - forkContextMessages: summaryContext, + forkContextMessages: summaryContext.messages, } - logForDebugging( - `[AgentSummary] Forking for summary, ${summaryContext.length} messages in context`, + logForDebuggingImpl( + `[AgentSummary] Forking for summary, ${summaryContext.messages.length} messages in context`, ) // Create abort controller for this summary @@ -137,9 +137,9 @@ export function startAgentSummarization( // ContentReplacementState is cloned by default in createSubagentContext // from forkParams.toolUseContext (the subagent's LIVE state captured at // onCacheSafeParams time). No explicit override needed. - const result = await runForkedAgent({ + const result = await runForkedAgentImpl({ promptMessages: [ - createUserMessage({ content: buildSummaryPrompt(previousSummary) }), + createSummaryPromptMessage(buildSummaryPrompt(previousSummary)), ], cacheSafeParams: forkParams, canUseTool, @@ -167,18 +167,18 @@ export function startAgentSummarization( const textBlock = contentArr.find(b => b.type === 'text') if (textBlock?.type === 'text' && textBlock.text.trim()) { const summaryText = textBlock.text.trim() - logForDebugging( + logForDebuggingImpl( `[AgentSummary] Summary result for ${taskId}: ${summaryText}`, ) - lastHandledTranscriptFingerprint = transcriptFingerprint + lastHandledTranscriptFingerprint = summaryContext.fingerprint previousSummary = summaryText - updateAgentSummary(taskId, summaryText, setAppState) + updateAgentSummaryImpl(taskId, summaryText, setAppState) break } } } catch (e) { if (!stopped && e instanceof Error) { - logError(e) + logErrorImpl(e) } } finally { summaryAbortController = null @@ -191,14 +191,14 @@ export function startAgentSummarization( function scheduleNext(): void { if (stopped) return - timeoutId = setTimeout(runSummary, SUMMARY_INTERVAL_MS) + timeoutId = setTimeoutImpl(runSummary, SUMMARY_INTERVAL_MS) } function stop(): void { - logForDebugging(`[AgentSummary] Stopping summarization for ${taskId}`) + logForDebuggingImpl(`[AgentSummary] Stopping summarization for ${taskId}`) stopped = true if (timeoutId) { - clearTimeout(timeoutId) + clearTimeoutImpl(timeoutId) timeoutId = null } if (summaryAbortController) { diff --git a/src/services/AgentSummary/summaryContext.ts b/src/services/AgentSummary/summaryContext.ts index 894a21e360..d4c00e1d42 100644 --- a/src/services/AgentSummary/summaryContext.ts +++ b/src/services/AgentSummary/summaryContext.ts @@ -1,4 +1,5 @@ -import { createHash } from 'crypto' +import { createHash } from 'node:crypto' +import { filterIncompleteToolCalls } from '@claude-code-best/builtin-tools/tools/AgentTool/filterIncompleteToolCalls.js' import type { Message } from '../../types/message.js' export const MAX_SUMMARY_CONTEXT_MESSAGES = 120 @@ -178,3 +179,41 @@ export function selectSummaryContextMessages( return selected } + +export type SummaryContextBuildResult = { + messages: Message[] + fingerprint: string | null + skipReason?: 'too_small' | 'unchanged' +} + +export function buildSummaryContext( + messages: Message[], + previousFingerprint: string | null, +): SummaryContextBuildResult { + const cleanMessages = filterIncompleteToolCalls(messages) + const boundedMessages = filterIncompleteToolCalls( + selectSummaryContextMessages(cleanMessages), + ) + const fingerprint = getSummaryContextFingerprint(boundedMessages) + + if (fingerprint && fingerprint === previousFingerprint) { + return { + messages: boundedMessages, + fingerprint, + skipReason: 'unchanged', + } + } + + if (boundedMessages.length < 3) { + return { + messages: boundedMessages, + fingerprint, + skipReason: 'too_small', + } + } + + return { + messages: boundedMessages, + fingerprint, + } +} diff --git a/src/services/AgentSummary/summaryPrompt.ts b/src/services/AgentSummary/summaryPrompt.ts new file mode 100644 index 0000000000..ce3138f2a3 --- /dev/null +++ b/src/services/AgentSummary/summaryPrompt.ts @@ -0,0 +1,32 @@ +import { randomUUID, type UUID } from 'node:crypto' +import type { UserMessage } from '../../types/message.js' + +export function buildSummaryPrompt(previousSummary: string | null): string { + const prevLine = previousSummary + ? `\nPrevious: "${previousSummary}" — say something NEW.\n` + : '' + + return `Describe your most recent action in 3-5 words using present tense (-ing). Name the file or function, not the branch. Do not use tools. +${prevLine} +Good: "Reading runAgent.ts" +Good: "Fixing null check in validate.ts" +Good: "Running auth module tests" +Good: "Adding retry logic to fetchUser" + +Bad (past tense): "Analyzed the branch diff" +Bad (too vague): "Investigating the issue" +Bad (too long): "Reviewing full branch diff and AgentTool.tsx integration" +Bad (branch name): "Analyzed adam/background-summary branch diff"` +} + +export function createSummaryPromptMessage(content: string): UserMessage { + return { + type: 'user', + message: { + role: 'user', + content, + }, + uuid: randomUUID() as UUID, + timestamp: new Date().toISOString(), + } +} diff --git a/src/utils/__tests__/ndjsonFramer.test.ts b/src/utils/__tests__/ndjsonFramer.test.ts index 35174162a3..344c1e58c3 100644 --- a/src/utils/__tests__/ndjsonFramer.test.ts +++ b/src/utils/__tests__/ndjsonFramer.test.ts @@ -88,4 +88,66 @@ describe('attachNdjsonFramer', () => { expect(errors[0]?.message).toContain('NDJSON frame exceeded') expect(socket.destroyed).toBe(true) }) + + test('lets callers own oversized-frame shutdown when configured', () => { + const socket = createTestSocket() + const errors: Error[] = [] + + attachNdjsonFramer( + socket, + () => undefined, + text => JSON.parse(text) as unknown, + { + maxFrameBytes: 8, + onFrameError: error => errors.push(error), + destroyOnFrameError: false, + }, + ) + + socket.emitData(Buffer.from('{"long":true}\n')) + + expect(errors[0]?.message).toContain('NDJSON frame exceeded') + expect(socket.destroyed).toBe(false) + }) + + test('reports malformed non-empty frames without changing default compatibility', () => { + const socket = createTestSocket() + const messages: unknown[] = [] + const errors: Error[] = [] + + attachNdjsonFramer( + socket, + msg => messages.push(msg), + text => JSON.parse(text) as unknown, + { + onInvalidFrame: error => errors.push(error), + }, + ) + + socket.emitData(Buffer.from('{not-json\n')) + + expect(messages).toEqual([]) + expect(errors).toHaveLength(1) + expect(socket.destroyed).toBe(false) + }) + + test('destroys malformed frames when configured by the caller', () => { + const socket = createTestSocket() + const errors: Error[] = [] + + attachNdjsonFramer( + socket, + () => undefined, + text => JSON.parse(text) as unknown, + { + destroyOnInvalidFrame: true, + onInvalidFrame: error => errors.push(error), + }, + ) + + socket.emitData(Buffer.from('{not-json\n')) + + expect(errors).toHaveLength(1) + expect(socket.destroyed).toBe(true) + }) }) diff --git a/src/utils/__tests__/teammateMailbox.test.ts b/src/utils/__tests__/teammateMailbox.test.ts index 7f479ed361..f6279dab7a 100644 --- a/src/utils/__tests__/teammateMailbox.test.ts +++ b/src/utils/__tests__/teammateMailbox.test.ts @@ -3,7 +3,7 @@ import { mkdir, readFile, rm, writeFile } from 'node:fs/promises' import { mkdtempSync } from 'node:fs' import { tmpdir } from 'node:os' import { dirname, join } from 'node:path' -import type { Message } from '../../types/message.js' +import type { Message } from 'src/types/message.js' import { compactMailboxMessages, getLastPeerDmSummary, @@ -13,13 +13,14 @@ import { markMessagesAsRead, markMessagesAsReadByPredicate, MAX_MAILBOX_MESSAGE_TEXT_BYTES, + MAX_MAILBOX_FILE_BYTES, MAX_MAILBOX_MESSAGES, MAX_READ_MAILBOX_MESSAGES, MAX_UNREAD_PROTOCOL_MAILBOX_MESSAGES, readMailbox, type TeammateMessage, writeToMailbox, -} from '../teammateMailbox.js' +} from 'src/utils/teammateMailbox.js' let tempHome = '' let previousConfigDir: string | undefined @@ -55,21 +56,6 @@ async function readRawMailbox( return JSON.parse(content) as TeammateMessage[] } -beforeEach(() => { - previousConfigDir = process.env.CLAUDE_CONFIG_DIR - tempHome = mkdtempSync(join(tmpdir(), 'teammate-mailbox-')) - process.env.CLAUDE_CONFIG_DIR = tempHome -}) - -afterEach(async () => { - if (previousConfigDir === undefined) { - delete process.env.CLAUDE_CONFIG_DIR - } else { - process.env.CLAUDE_CONFIG_DIR = previousConfigDir - } - await rm(tempHome, { recursive: true, force: true }) -}) - describe('compactMailboxMessages', () => { test('prioritizes unread messages and keeps only recent read history', () => { const compacted = compactMailboxMessages( @@ -175,9 +161,35 @@ describe('compactMailboxMessages', () => { expect(compacted.length).toBeLessThan(20) expect(compacted.at(-1)?.text).toContain('msg-19') }) + + test('returns an empty mailbox when even one message exceeds retained budget', () => { + const compacted = compactMailboxMessages([message('too-large', false)], { + maxMessages: 10, + maxReadMessages: 0, + maxRetainedBytes: 1, + }) + + expect(compacted).toEqual([]) + }) }) describe('teammate mailbox retention', () => { + beforeEach(() => { + previousConfigDir = process.env.CLAUDE_CONFIG_DIR + tempHome = mkdtempSync(join(tmpdir(), 'teammate-mailbox-')) + process.env.CLAUDE_CONFIG_DIR = tempHome + }) + + afterEach(async () => { + if (previousConfigDir === undefined) { + delete process.env.CLAUDE_CONFIG_DIR + } else { + process.env.CLAUDE_CONFIG_DIR = previousConfigDir + } + await rm(tempHome, { recursive: true, force: true }) + tempHome = '' + }) + test('writeToMailbox compacts oversized unread inbox files', async () => { const existing = Array.from( { length: MAX_MAILBOX_MESSAGES + 20 }, @@ -326,6 +338,76 @@ describe('teammate mailbox retention', () => { await expect(readMailbox('worker', 'alpha')).rejects.toThrow() }) + + test('readMailbox rejects non-array mailbox files', async () => { + const inboxPath = getInboxPath('worker', 'alpha') + await mkdir(dirname(inboxPath), { recursive: true }) + await writeFile(inboxPath, JSON.stringify({ text: 'not an array' }), 'utf-8') + + await expect(readMailbox('worker', 'alpha')).rejects.toThrow( + 'expected message array', + ) + }) + + test('readMailbox rejects malformed stored message shapes', async () => { + const inboxPath = getInboxPath('worker', 'alpha') + await mkdir(dirname(inboxPath), { recursive: true }) + await writeFile( + inboxPath, + JSON.stringify([{ from: 'lead', text: 'missing timestamp' }]), + 'utf-8', + ) + + await expect(readMailbox('worker', 'alpha')).rejects.toThrow( + 'Invalid mailbox message shape', + ) + }) + + test('readMailbox rejects non-object stored messages', async () => { + const inboxPath = getInboxPath('worker', 'alpha') + await mkdir(dirname(inboxPath), { recursive: true }) + await writeFile(inboxPath, JSON.stringify(['not an object']), 'utf-8') + + await expect(readMailbox('worker', 'alpha')).rejects.toThrow( + 'expected object', + ) + }) + + test('readMailbox rejects oversized mailbox files before parsing', async () => { + const inboxPath = getInboxPath('worker', 'alpha') + await mkdir(dirname(inboxPath), { recursive: true }) + await writeFile(inboxPath, `[${' '.repeat(MAX_MAILBOX_FILE_BYTES)}]`, 'utf-8') + + await expect(readMailbox('worker', 'alpha')).rejects.toThrow( + 'Mailbox file exceeds', + ) + }) + + test('markMessageAsReadByIdentity returns false for missing mailbox files', async () => { + await expect( + markMessageAsReadByIdentity('worker', 'alpha', message('absent', false)), + ).resolves.toBe(false) + }) + + test('markMessageAsReadByIdentity returns false when the expected message moved out', async () => { + await seedMailbox('worker', 'alpha', [message('other', false)]) + + await expect( + markMessageAsReadByIdentity('worker', 'alpha', message('missing', false)), + ).resolves.toBe(false) + + expect((await readRawMailbox('worker', 'alpha'))[0]?.read).toBe(false) + }) + + test('markMessageAsReadByIdentity returns false on corrupt mailbox content', async () => { + const inboxPath = getInboxPath('worker', 'alpha') + await mkdir(dirname(inboxPath), { recursive: true }) + await writeFile(inboxPath, '{not-json', 'utf-8') + + await expect( + markMessageAsReadByIdentity('worker', 'alpha', message('missing', false)), + ).resolves.toBe(false) + }) }) describe('getLastPeerDmSummary', () => { diff --git a/src/utils/__tests__/udsMessaging.test.ts b/src/utils/__tests__/udsMessaging.test.ts index ef943cb76b..392daa1acf 100644 --- a/src/utils/__tests__/udsMessaging.test.ts +++ b/src/utils/__tests__/udsMessaging.test.ts @@ -3,24 +3,31 @@ import { chmod, mkdir, mkdtemp, + readdir, rm, stat, symlink, unlink, + writeFile, } from 'node:fs/promises' +import { createHash } from 'node:crypto' import { createConnection, createServer } from 'node:net' import { dirname, join } from 'node:path' import { tmpdir } from 'node:os' import { drainInbox, + getDefaultUdsSocketPath, MAX_UDS_INBOX_ENTRIES, MAX_UDS_INBOX_BYTES, MAX_UDS_FRAME_BYTES, + MAX_UDS_CLIENTS, + formatUdsAddress, parseUdsTarget, sendUdsMessage, setOnEnqueue, startUdsMessaging, stopUdsMessaging, + UDS_AUTH_TIMEOUT_MS, } from '../udsMessaging.js' let previousConfigDir: string | undefined @@ -192,7 +199,7 @@ describe('UDS inbox retention', () => { try { const { isPeerAlive } = await import('../udsClient.js') - expect(await isPeerAlive(path)).toBe(false) + expect(await isPeerAlive(path, 3_000, 'test-token')).toBe(false) } finally { await closeServer(receiver) if (process.platform !== 'win32') { @@ -210,6 +217,12 @@ describe('UDS inbox retention', () => { ) }) + test('sendUdsMessage fails closed before connecting without an auth token', async () => { + await expect( + sendUdsMessage(socketPath('no-auth-token'), { type: 'text', data: 'x' }), + ).rejects.toThrow('without auth token') + }) + test('drained entries never expose the UDS auth token', async () => { const path = socketPath('strip-token') await startUdsMessaging(path, { isExplicit: true }) @@ -232,6 +245,7 @@ describe('UDS inbox retention', () => { await startUdsMessaging(path, { isExplicit: true }) const response = await new Promise((resolve, reject) => { + let responseText = '' const conn = createConnection(path, () => { conn.write(`${JSON.stringify({ type: 'text', data: 'bad' })}\n`) }) @@ -242,10 +256,10 @@ describe('UDS inbox retention', () => { conn.on('data', chunk => { const text = chunk.toString('utf-8') if (text.includes('\n')) { - conn.end() - resolve(text) + responseText = text } }) + conn.on('close', () => resolve(responseText)) conn.on('error', reject) }) @@ -253,6 +267,56 @@ describe('UDS inbox retention', () => { expect(drainInbox()).toEqual([]) }) + test('disconnects malformed JSON clients without enqueueing inbox work', async () => { + const path = socketPath('malformed-client') + await startUdsMessaging(path, { isExplicit: true }) + + const response = await new Promise((resolve, reject) => { + let responseText = '' + const conn = createConnection(path, () => { + conn.write('{not-json\n') + }) + conn.setTimeout(5_000, () => { + conn.destroy() + reject(new Error('Timed out waiting for malformed frame close')) + }) + conn.on('data', chunk => { + responseText += chunk.toString('utf-8') + }) + conn.on('close', () => resolve(responseText)) + conn.on('error', reject) + }) + + const parsed = JSON.parse(response) + expect(parsed.type).toBe('error') + expect(parsed.data).toBe('invalid frame') + expect(drainInbox()).toEqual([]) + }) + + test('disconnects idle unauthenticated clients', async () => { + const path = socketPath('idle-client') + await startUdsMessaging(path, { isExplicit: true }) + + const response = await new Promise((resolve, reject) => { + let responseText = '' + const conn = createConnection(path) + conn.setTimeout(UDS_AUTH_TIMEOUT_MS + 2_000, () => { + conn.destroy() + reject(new Error('Timed out waiting for auth timeout close')) + }) + conn.on('data', chunk => { + responseText += chunk.toString('utf-8') + }) + conn.on('close', () => resolve(responseText)) + conn.on('error', reject) + }) + + const parsed = JSON.parse(response) + expect(parsed.type).toBe('error') + expect(parsed.data).toBe('authentication timeout') + expect(drainInbox()).toEqual([]) + }) + test('destroys oversized frames before enqueueing inbox work', async () => { const path = socketPath('oversized') await startUdsMessaging(path, { isExplicit: true }) @@ -272,6 +336,14 @@ describe('UDS inbox retention', () => { expect(drainInbox()).toEqual([]) }) + test('default socket path is regenerated after stop', async () => { + const firstPath = getDefaultUdsSocketPath() + await startUdsMessaging(firstPath) + await stopUdsMessaging() + + expect(getDefaultUdsSocketPath()).not.toBe(firstPath) + }) + test('rejects oversized receiver responses before retaining them', async () => { const path = socketPath('oversized-response') if (process.platform !== 'win32') { @@ -303,9 +375,71 @@ describe('UDS inbox retention', () => { } }) + test('rejects closed receiver responses without waiting for timeout', async () => { + const path = socketPath('closed-response') + if (process.platform !== 'win32') { + await mkdir(dirname(path), { recursive: true }) + } + const receiver = createServer(socket => { + socket.end() + }) + await new Promise((resolve, reject) => { + receiver.on('error', reject) + receiver.listen(path, () => resolve()) + }) + + try { + await expect( + sendUdsMessage( + path, + { type: 'text', data: 'hello' }, + { authToken: 'test-token' }, + ), + ).rejects.toThrow('before response') + } finally { + await closeServer(receiver) + if (process.platform !== 'win32') { + await unlink(path).catch(() => undefined) + } + } + }) + + test('rejects malformed receiver responses without waiting for timeout', async () => { + const path = socketPath('malformed-response') + if (process.platform !== 'win32') { + await mkdir(dirname(path), { recursive: true }) + } + const receiver = createServer(socket => { + socket.on('data', () => { + socket.write('{not-json\n') + }) + }) + await new Promise((resolve, reject) => { + receiver.on('error', reject) + receiver.listen(path, () => resolve()) + }) + + try { + await expect( + sendUdsMessage( + path, + { type: 'text', data: 'hello' }, + { authToken: 'test-token' }, + ), + ).rejects.toThrow('Invalid UDS response frame') + } finally { + await closeServer(receiver) + if (process.platform !== 'win32') { + await unlink(path).catch(() => undefined) + } + } + }) + test('rejects inline auth token UDS targets instead of parsing them', async () => { const path = socketPath('inline-token') + expect(formatUdsAddress(path)).toBe(`uds:${path}`) + const targetWithToken = `${path}#token=secret` expect(() => parseUdsTarget(targetWithToken)).toThrow('inline auth token') try { @@ -320,6 +454,23 @@ describe('UDS inbox retention', () => { ) }) + test('fails closed and cleans temp files when capability target is occupied', async () => { + const path = socketPath('capability-target-dir') + const capabilityDir = join(tempConfigDir, 'messaging-capabilities') + const capabilityName = `${createHash('sha256').update(path).digest('hex')}.json` + await mkdir(join(capabilityDir, capabilityName), { + recursive: true, + mode: 0o700, + }) + + await expect( + startUdsMessaging(path, { isExplicit: true }), + ).rejects.toThrow() + + expect(process.env.CLAUDE_CODE_MESSAGING_SOCKET).toBeUndefined() + expect(await readdir(capabilityDir)).toEqual([capabilityName]) + }) + if (process.platform !== 'win32') { test('creates the listening socket with owner-only permissions', async () => { const path = socketPath('socket-mode') @@ -341,9 +492,11 @@ describe('UDS inbox retention', () => { await chmod(capabilityDir, 0o755) try { + const path = socketPath('broad-capdir') await expect( - startUdsMessaging(socketPath('broad-capdir'), { isExplicit: true }), + startUdsMessaging(path, { isExplicit: true }), ).rejects.toThrow('permissions are too broad') + await expect(stat(path)).rejects.toThrow() } finally { if (previousConfigDir === undefined) { delete process.env.CLAUDE_CONFIG_DIR @@ -397,5 +550,65 @@ describe('UDS inbox retention', () => { await rm(parent, { recursive: true, force: true }) } }) + + test('fails closed when an explicit socket parent is a file', async () => { + const parentFile = join( + tmpdir(), + `uds-socket-parent-file-${process.pid}-${Date.now()}-${Math.random().toString(16).slice(2)}`, + ) + await writeFile(parentFile, 'not a directory', 'utf-8') + + try { + await expect( + startUdsMessaging(join(parentFile, 'messaging.sock'), { + isExplicit: true, + }), + ).rejects.toThrow('socket parent is not a directory') + } finally { + await rm(parentFile, { force: true }) + } + }) + + test('stop tolerates an already removed socket path', async () => { + const path = socketPath('already-removed') + await startUdsMessaging(path, { isExplicit: true }) + await unlink(path) + + await stopUdsMessaging() + + expect(process.env.CLAUDE_CODE_MESSAGING_SOCKET).toBeUndefined() + }) + + test('rejects clients over the configured connection cap', async () => { + const path = socketPath('client-cap') + await startUdsMessaging(path, { isExplicit: true }) + const sockets: ReturnType[] = [] + + try { + for (let i = 0; i < MAX_UDS_CLIENTS; i++) { + const socket = await new Promise>( + (resolve, reject) => { + const conn = createConnection(path, () => resolve(conn)) + conn.on('error', reject) + }, + ) + sockets.push(socket) + } + + await new Promise((resolve, reject) => { + const extra = createConnection(path) + extra.on('close', () => resolve()) + extra.on('error', reject) + extra.setTimeout(5_000, () => { + extra.destroy() + reject(new Error('Timed out waiting for client cap close')) + }) + }) + } finally { + for (const socket of sockets) { + socket.destroy() + } + } + }) } }) diff --git a/src/utils/__tests__/udsResponseReader.test.ts b/src/utils/__tests__/udsResponseReader.test.ts new file mode 100644 index 0000000000..71203da621 --- /dev/null +++ b/src/utils/__tests__/udsResponseReader.test.ts @@ -0,0 +1,171 @@ +import { describe, expect, test } from 'bun:test' +import { EventEmitter } from 'node:events' +import type { Socket } from 'node:net' +import { attachUdsResponseReader } from '../udsResponseReader.js' + +class FakeSocket extends EventEmitter { + destroyed = false + ended = false + + destroy(): this { + this.destroyed = true + this.emit('close', true) + return this + } + + end(): this { + this.ended = true + this.emit('close', false) + return this + } + + emitData(chunk: Buffer): void { + this.emit('data', chunk) + } +} + +function asSocket(socket: FakeSocket): Socket { + return socket as unknown as Socket +} + +describe('attachUdsResponseReader', () => { + test('tracks byte limits across split multibyte response chunks', () => { + const socket = new FakeSocket() + let settled = false + let settledError: Error | undefined + + attachUdsResponseReader(asSocket(socket), { + maxFrameBytes: 128, + onSettled: error => { + settled = true + settledError = error + }, + }) + + const multibyte = String.fromCodePoint(0x20ac) + const frame = Buffer.from( + JSON.stringify({ type: 'response', data: `ok ${multibyte}` }) + '\n', + 'utf8', + ) + const multibyteStart = frame.indexOf(Buffer.from(multibyte, 'utf8')[0]) + + socket.emitData(frame.subarray(0, multibyteStart + 1)) + expect(settled).toBe(false) + + socket.emitData(frame.subarray(multibyteStart + 1)) + expect(settled).toBe(true) + expect(settledError).toBeUndefined() + expect(socket.ended).toBe(true) + }) + + test('rejects malformed response frames immediately', () => { + const socket = new FakeSocket() + let settledError: Error | undefined + + attachUdsResponseReader(asSocket(socket), { + maxFrameBytes: 128, + onSettled: error => { + settledError = error + }, + }) + + socket.emitData(Buffer.from('{bad-json}\n')) + + expect(settledError?.message).toBe('Invalid UDS response frame') + expect(socket.destroyed).toBe(true) + }) + + test('skips blank frames before a valid response', () => { + const socket = new FakeSocket() + let settled = false + let settledError: Error | undefined + + attachUdsResponseReader(asSocket(socket), { + maxFrameBytes: 128, + onSettled: error => { + settled = true + settledError = error + }, + }) + + socket.emitData(Buffer.from('\n \n')) + expect(settled).toBe(false) + + socket.emitData(Buffer.from(`${JSON.stringify({ type: 'response' })}\n`)) + expect(settled).toBe(true) + expect(settledError).toBeUndefined() + expect(socket.ended).toBe(true) + }) + + test('rejects receiver error frames', () => { + const socket = new FakeSocket() + let settledError: Error | undefined + + attachUdsResponseReader(asSocket(socket), { + maxFrameBytes: 128, + onSettled: error => { + settledError = error + }, + }) + + socket.emitData( + Buffer.from(`${JSON.stringify({ type: 'error', data: 'denied' })}\n`), + ) + + expect(settledError?.message).toBe('denied') + expect(socket.destroyed).toBe(true) + }) + + test('uses custom socket error formatting', () => { + const socket = new FakeSocket() + let settledError: Error | undefined + + attachUdsResponseReader(asSocket(socket), { + maxFrameBytes: 128, + onSettled: error => { + settledError = error + }, + formatSocketError: error => + new Error(`wrapped:${(error as Error).message}`), + }) + + socket.emit('error', new Error('connect failed')) + + expect(settledError?.message).toBe('wrapped:connect failed') + expect(socket.destroyed).toBe(true) + }) + + test('rejects socket end before response', () => { + const socket = new FakeSocket() + let settledError: Error | undefined + + attachUdsResponseReader(asSocket(socket), { + maxFrameBytes: 128, + onSettled: error => { + settledError = error + }, + }) + + socket.emit('end') + + expect(settledError?.message).toBe('UDS socket ended before response') + expect(socket.destroyed).toBe(true) + }) + + test('rejects clean socket close before response', () => { + const socket = new FakeSocket() + let settledError: Error | undefined + + attachUdsResponseReader(asSocket(socket), { + maxFrameBytes: 128, + onSettled: error => { + settledError = error + }, + }) + + socket.emit('close', false) + + expect(settledError?.message).toBe('UDS socket closed before response') + expect(socket.destroyed).toBe(true) + }) +}) diff --git a/src/utils/messages/systemInit.ts b/src/utils/messages/systemInit.ts index 4585c78172..fcb9e74d11 100644 --- a/src/utils/messages/systemInit.ts +++ b/src/utils/messages/systemInit.ts @@ -87,10 +87,8 @@ export function buildSystemInitMessage(inputs: SystemInitInputs): SDKMessage { // Hidden from public SDK types — ant-only UDS messaging socket path if (feature('UDS_INBOX')) { /* eslint-disable @typescript-eslint/no-require-imports */ - const udsMessaging = - require('../udsMessaging.js') as typeof import('../udsMessaging.js') ;(initMessage as Record).messaging_socket_path = - udsMessaging.getUdsMessagingSocketPath() + require('../udsMessaging.js').getUdsMessagingSocketPath() /* eslint-enable @typescript-eslint/no-require-imports */ } initMessage.fast_mode_state = getFastModeState(inputs.model, inputs.fastMode) diff --git a/src/utils/ndjsonFramer.ts b/src/utils/ndjsonFramer.ts index 7832e93036..69717fc11e 100644 --- a/src/utils/ndjsonFramer.ts +++ b/src/utils/ndjsonFramer.ts @@ -10,11 +10,15 @@ import type { Socket } from 'net' export type NdjsonFramerOptions = { maxFrameBytes?: number onFrameError?: (error: Error) => void + destroyOnFrameError?: boolean + onInvalidFrame?: (error: Error) => void + destroyOnInvalidFrame?: boolean } /** * Attach an NDJSON framer to a socket. Calls `onMessage` for each - * complete JSON line received. Malformed lines are silently skipped. + * complete JSON line received. Malformed lines are skipped by default; + * callers may opt into error callbacks or socket destruction. * * @param parse - Optional custom JSON parser (defaults to JSON.parse). * Useful when the caller uses a wrapped parser like jsonParse @@ -35,15 +39,26 @@ export function attachNdjsonFramer( `NDJSON frame exceeded ${maxFrameBytes} bytes (${bytes})`, ) options.onFrameError?.(error) - socket.destroy(error) + if (options.destroyOnFrameError ?? true) { + socket.destroy(error) + } + } + + const rejectInvalidFrame = (error: unknown): void => { + const frameError = + error instanceof Error ? error : new Error('Invalid NDJSON frame') + options.onInvalidFrame?.(frameError) + if (options.destroyOnInvalidFrame ?? false) { + socket.destroy(frameError) + } } const emitLine = (line: string): void => { if (!line.trim()) return try { onMessage(parse(line)) - } catch { - // Malformed JSON — skip + } catch (error) { + rejectInvalidFrame(error) } } diff --git a/src/utils/swarm/inProcessRunner.ts b/src/utils/swarm/inProcessRunner.ts index eaab58ef76..06fde705a4 100644 --- a/src/utils/swarm/inProcessRunner.ts +++ b/src/utils/swarm/inProcessRunner.ts @@ -1246,13 +1246,8 @@ export async function runInProcessTeammate( // Track in-progress tool use IDs for animation in transcript view let inProgressToolUseIDs = task.inProgressToolUseIDs if (message.type === 'assistant') { - for (const block of Array.isArray(message.message!.content) - ? message.message!.content - : []) { - if ( - typeof block !== 'string' && - block.type === 'tool_use' - ) { + for (const block of (Array.isArray(message.message!.content) ? message.message!.content : [])) { + if (typeof block !== 'string' && block.type === 'tool_use') { inProgressToolUseIDs = new Set([ ...(inProgressToolUseIDs ?? []), block.id, @@ -1323,10 +1318,7 @@ export async function runInProcessTeammate( setAppState, ) if (currentAutonomyRunId) { - await markAutonomyRunFailed( - currentAutonomyRunId, - ERROR_MESSAGE_USER_ABORT, - ) + await markAutonomyRunFailed(currentAutonomyRunId, ERROR_MESSAGE_USER_ABORT) currentAutonomyRunId = undefined } } else if (currentAutonomyRunId) { diff --git a/src/utils/udsMessaging.ts b/src/utils/udsMessaging.ts index 94b73dcd6a..b30cba1378 100644 --- a/src/utils/udsMessaging.ts +++ b/src/utils/udsMessaging.ts @@ -82,6 +82,8 @@ export const MAX_UDS_INBOX_ENTRIES = 1_000 export const MAX_UDS_FRAME_BYTES = 64 * 1024 export const MAX_UDS_INBOX_BYTES = 2 * 1024 * 1024 export const MAX_UDS_CLIENTS = 128 +export const UDS_AUTH_TIMEOUT_MS = 2_000 +export const UDS_IDLE_TIMEOUT_MS = 30_000 // --------------------------------------------------------------------------- // Public API — socket path helpers @@ -339,6 +341,43 @@ function writeSocketMessage(socket: Socket, message: UdsMessage): void { socket.write(jsonStringify(message) + '\n') } +function writeSocketMessageAndDestroy(socket: Socket, message: UdsMessage): void { + if (socket.destroyed) return + socket.write(jsonStringify(message) + '\n', () => { + if (!socket.destroyed) socket.destroy() + }) +} + +function writeSocketErrorAndDestroy(socket: Socket, data: string): void { + writeSocketMessageAndDestroy(socket, { + type: 'error', + data, + ts: new Date().toISOString(), + }) +} + +function unrefTimer(timer: ReturnType): void { + const maybeUnref = (timer as { unref?: () => void }).unref + if (typeof maybeUnref === 'function') { + maybeUnref.call(timer) + } +} + +async function closeServer(serverToClose: Server): Promise { + await new Promise(resolve => { + serverToClose.close(() => resolve()) + }) +} + +async function removeSocketPath(path: string): Promise { + if (process.platform === 'win32') return + try { + await unlink(path) + } catch { + // Already gone. + } +} + function stripAuthToken(message: UdsMessage): UdsMessage { const { authToken: _authToken, ...metaWithoutAuth } = message.meta ?? {} return { @@ -391,10 +430,9 @@ export async function startUdsMessaging( } const token = ensureAuthToken() + let startedServer: Server | null = null + let exportedSocketEnv = false try { - await writeCapabilityFile(path, token) - socketPath = path - await new Promise((resolve, reject) => { const srv = createServer(socket => { if (clients.size >= MAX_UDS_CLIENTS) { @@ -408,6 +446,24 @@ export async function startUdsMessaging( logForDebugging( `[udsMessaging] client connected (total: ${clients.size})`, ) + let authenticated = false + let closing = false + const closeWithError = (data: string): void => { + if (closing || socket.destroyed) return + closing = true + socket.pause() + writeSocketErrorAndDestroy(socket, data) + } + const authTimer = setTimeout(() => { + if (authenticated || socket.destroyed) return + logForDebugging('[udsMessaging] closing unauthenticated idle client') + closeWithError('authentication timeout') + }, UDS_AUTH_TIMEOUT_MS) + unrefTimer(authTimer) + socket.setTimeout(UDS_IDLE_TIMEOUT_MS, () => { + logForDebugging('[udsMessaging] closing idle client') + closeWithError('idle timeout') + }) attachNdjsonFramer( socket, @@ -416,17 +472,13 @@ export async function startUdsMessaging( logForDebugging( `[udsMessaging] rejected unauthenticated message type=${msg.type}`, ) - if (!socket.destroyed) { - socket.write( - jsonStringify({ - type: 'error', - data: 'unauthorized', - ts: new Date().toISOString(), - } satisfies UdsMessage) + '\n', - ) - } + closeWithError('unauthorized') return } + if (!authenticated) { + authenticated = true + clearTimeout(authTimer) + } // Handle ping with automatic pong if (msg.type === 'ping') { @@ -447,11 +499,7 @@ export async function startUdsMessaging( status: 'pending', } if (!enqueueInboxEntry(entry)) { - writeSocketMessage(socket, { - type: 'error', - data: 'inbox full', - ts: new Date().toISOString(), - }) + closeWithError('inbox full') return } logForDebugging( @@ -470,21 +518,40 @@ export async function startUdsMessaging( maxFrameBytes: MAX_UDS_FRAME_BYTES, onFrameError: error => { logForDebugging(`[udsMessaging] ${error.message}`) + closeWithError(error.message) + }, + onInvalidFrame: error => { + logForDebugging( + `[udsMessaging] invalid client frame: ${errorMessage(error)}`, + ) + closeWithError('invalid frame') }, + destroyOnFrameError: false, }, ) socket.on('close', () => { + clearTimeout(authTimer) clients.delete(socket) }) socket.on('error', err => { + clearTimeout(authTimer) clients.delete(socket) logForDebugging(`[udsMessaging] client error: ${errorMessage(err)}`) }) }) - srv.on('error', reject) + const rejectBeforeListen = (error: Error): void => { + reject(error) + } + const logRuntimeError = (error: Error): void => { + logForDebugging( + `[udsMessaging] server error on ${path}${opts?.isExplicit ? ' (explicit)' : ''}: ${errorMessage(error)}`, + ) + } + + srv.once('error', rejectBeforeListen) srv.listen(path, () => { void (async () => { @@ -492,19 +559,41 @@ export async function startUdsMessaging( if (process.platform !== 'win32') { await chmod(path, 0o600) } + srv.off('error', rejectBeforeListen) + srv.on('error', logRuntimeError) server = srv - // Export so child processes can discover the socket - process.env.CLAUDE_CODE_MESSAGING_SOCKET = path - logForDebugging( - `[udsMessaging] server listening on ${path}${opts?.isExplicit ? ' (explicit)' : ''}`, - ) + startedServer = srv resolve() } catch (error) { - srv.close(() => reject(error)) + srv.off('error', rejectBeforeListen) + const closeError = + error instanceof Error ? error : new Error(errorMessage(error)) + let rejected = false + const rejectOnce = (): void => { + if (rejected) return + rejected = true + reject(closeError) + } + const fallback = setTimeout(rejectOnce, 1_000) + unrefTimer(fallback) + srv.close(() => { + clearTimeout(fallback) + rejectOnce() + }) } })() }) }) + + await writeCapabilityFile(path, token) + socketPath = path + // Export so child processes can discover the socket only after the + // capability file exists and the listener is ready. + process.env.CLAUDE_CODE_MESSAGING_SOCKET = path + exportedSocketEnv = true + logForDebugging( + `[udsMessaging] server listening on ${path}${opts?.isExplicit ? ' (explicit)' : ''}`, + ) } catch (error) { if (capabilityFilePath) { try { @@ -514,7 +603,18 @@ export async function startUdsMessaging( } capabilityFilePath = null } + if (startedServer) { + await closeServer(startedServer) + } + if (server === startedServer) { + server = null + } + await removeSocketPath(path) + if (exportedSocketEnv) { + delete process.env.CLAUDE_CODE_MESSAGING_SOCKET + } socketPath = null + defaultSocketPath = null authToken = null throw error } @@ -529,6 +629,7 @@ export async function startUdsMessaging( * Stop the UDS messaging server and clean up the socket file. */ export async function stopUdsMessaging(): Promise { + defaultSocketPath = null if (!server) return // Close all connected clients @@ -547,13 +648,7 @@ export async function stopUdsMessaging(): Promise { // Remove socket file (skip on Windows — pipe paths aren't files) if (socketPath) { - if (process.platform !== 'win32') { - try { - await unlink(socketPath) - } catch { - // Already gone - } - } + await removeSocketPath(socketPath) delete process.env.CLAUDE_CODE_MESSAGING_SOCKET logForDebugging( `[udsMessaging] server stopped, socket removed: ${socketPath}`, diff --git a/src/utils/udsResponseReader.ts b/src/utils/udsResponseReader.ts index bb8d21f40f..d86328aabe 100644 --- a/src/utils/udsResponseReader.ts +++ b/src/utils/udsResponseReader.ts @@ -1,4 +1,5 @@ import type { Socket } from 'net' +import { StringDecoder } from 'node:string_decoder' import { errorMessage } from './errors.js' import { jsonParse } from './slowOperations.js' import type { UdsMessage } from './udsMessaging.js' @@ -16,11 +17,11 @@ export function getChunkBytes(chunk: string | Buffer): number { : chunk.byteLength } -function parseResponseLine(line: string): UdsMessage | null { +function parseResponseLine(line: string): UdsMessage { try { return jsonParse(line) as UdsMessage } catch { - return null + throw new Error('Invalid UDS response frame') } } @@ -29,35 +30,58 @@ export function attachUdsResponseReader( options: UdsResponseReaderOptions, ): void { let buffer = '' + let bufferBytes = 0 let settled = false + const decoder = new StringDecoder('utf8') - const finish = (error?: Error): void => { + function cleanupListeners(): void { + socket.off('data', onData) + socket.off('error', onError) + socket.off('end', onEnd) + socket.off('close', onClose) + } + + function finish(error?: Error): void { if (settled) return settled = true + buffer = '' + bufferBytes = 0 + cleanupListeners() if (error) { - socket.destroy(error) + socket.destroy() } else { socket.end() } options.onSettled(error) } - socket.on('data', chunk => { - if ( - Buffer.byteLength(buffer, 'utf8') + getChunkBytes(chunk) > - options.maxFrameBytes - ) { + function onData(chunk: Buffer): void { + const decoded = decoder.write(chunk) + const decodedBytes = Buffer.byteLength(decoded, 'utf8') + if (bufferBytes + decodedBytes > options.maxFrameBytes) { finish(new Error('UDS response frame exceeded size limit')) return } - buffer += chunk.toString() - const lines = buffer.split('\n') - buffer = lines.pop() ?? '' - for (const line of lines) { - if (!line.trim()) continue - const response = parseResponseLine(line) - if (!response) continue + buffer += decoded + bufferBytes += decodedBytes + let newlineIndex = buffer.indexOf('\n') + while (newlineIndex !== -1) { + const line = buffer.slice(0, newlineIndex) + const consumed = buffer.slice(0, newlineIndex + 1) + buffer = buffer.slice(newlineIndex + 1) + bufferBytes -= Buffer.byteLength(consumed, 'utf8') + if (!line.trim()) { + newlineIndex = buffer.indexOf('\n') + continue + } + let response: UdsMessage + try { + response = parseResponseLine(line) + } catch (error) { + finish(error instanceof Error ? error : new Error(errorMessage(error))) + return + } if ( response.type === 'response' || (options.acceptPong === true && response.type === 'pong') @@ -69,13 +93,28 @@ export function attachUdsResponseReader( finish(new Error(response.data ?? 'UDS receiver rejected message')) return } + newlineIndex = buffer.indexOf('\n') } - }) + } - socket.on('error', error => { + function onError(error: Error): void { finish( options.formatSocketError?.(error) ?? (error instanceof Error ? error : new Error(errorMessage(error))), ) - }) + } + + function onEnd(): void { + finish(new Error('UDS socket ended before response')) + } + + function onClose(hadError: boolean): void { + if (hadError) return + finish(new Error('UDS socket closed before response')) + } + + socket.on('data', onData) + socket.on('error', onError) + socket.on('end', onEnd) + socket.on('close', onClose) }