Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/fix-malformed-tool-call-input.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'@workflow/ai': patch
---

Preserve malformed streamed tool-call input until repair hooks can run
79 changes: 78 additions & 1 deletion packages/ai/src/agent/do-stream-step.test.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import type { UIMessageChunk } from 'ai';
import { describe, expect, it } from 'vitest';
import { normalizeFinishReason } from './do-stream-step.js';
import { doStreamStep, normalizeFinishReason } from './do-stream-step.js';
import { safeParseToolCallInput } from './safe-parse-tool-call-input.js';

describe('normalizeFinishReason', () => {
describe('string finish reasons', () => {
Expand Down Expand Up @@ -122,3 +124,78 @@ describe('normalizeFinishReason', () => {
});
});
});

describe('safeParseToolCallInput', () => {
it('should parse valid JSON input', () => {
expect(safeParseToolCallInput('{"city":"San Francisco"}')).toEqual({
city: 'San Francisco',
});
});

it('should return empty object for undefined input', () => {
expect(safeParseToolCallInput(undefined)).toEqual({});
});

it('should preserve malformed input as a string', () => {
expect(safeParseToolCallInput('{"city":"San Francisco"')).toBe(
'{"city":"San Francisco"'
);
});
});

describe('doStreamStep', () => {
it('should not throw when streamed tool-call input is malformed JSON', async () => {
const writtenChunks: UIMessageChunk[] = [];
const writable = new WritableStream<UIMessageChunk>({
write: async (chunk) => {
writtenChunks.push(chunk);
},
});

const model = {
provider: 'mock-provider',
modelId: 'mock-model',
doStream: async () => ({
stream: new ReadableStream({
start(controller) {
controller.enqueue({ type: 'stream-start', warnings: [] });
controller.enqueue({
type: 'tool-call',
toolCallId: 'call-1',
toolName: 'getWeather',
input: '{"city":"San Francisco"',
});
controller.enqueue({
type: 'finish',
finishReason: 'tool-calls',
usage: {
inputTokens: { total: 10, noCache: 10 },
outputTokens: { total: 5, text: 0, reasoning: 5 },
},
});
controller.close();
},
}),
}),
};

const result = await doStreamStep(
[{ role: 'user', content: [{ type: 'text', text: 'test' }] }],
async () => model as any,
writable,
undefined,
{ sendStart: false }
);

expect(result.step.toolCalls).toHaveLength(1);
expect(result.step.toolCalls[0]?.input).toBe('{"city":"San Francisco"');
expect(writtenChunks).toContainEqual(
expect.objectContaining({
type: 'tool-input-available',
toolCallId: 'call-1',
toolName: 'getWeather',
input: '{"city":"San Francisco"',
})
);
});
});
43 changes: 14 additions & 29 deletions packages/ai/src/agent/do-stream-step.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import type {
TelemetrySettings,
} from './durable-agent.js';
import { getErrorMessage } from '../get-error-message.js';
import { safeParseToolCallInput } from './safe-parse-tool-call-input.js';
import { recordSpan } from './telemetry.js';
import type { CompatibleLanguageModel } from './types.js';

Expand Down Expand Up @@ -454,7 +455,7 @@ export async function doStreamStep(
type: 'tool-input-available',
toolCallId: part.toolCallId,
toolName: part.toolName,
input: JSON.parse(part.input || '{}'),
input: safeParseToolCallInput(part.input),
...(part.providerExecuted != null
? { providerExecuted: part.providerExecuted }
: {}),
Expand Down Expand Up @@ -779,6 +780,14 @@ function chunksToStep(
? v3FinishReason
: undefined;

const mapToolCall = (toolCall: LanguageModelV3ToolCall) => ({
type: 'tool-call' as const,
toolCallId: toolCall.toolCallId,
toolName: toolCall.toolName,
input: safeParseToolCallInput(toolCall.input),
dynamic: true as const,
});

const stepResult: StepResult<any> = {
stepNumber: 0, // Will be overridden by the caller
model: {
Expand All @@ -790,13 +799,7 @@ function chunksToStep(
experimental_context: undefined,
content: [
...(text ? [{ type: 'text' as const, text }] : []),
...toolCalls.map((toolCall) => ({
type: 'tool-call' as const,
toolCallId: toolCall.toolCallId,
toolName: toolCall.toolName,
input: JSON.parse(toolCall.input),
dynamic: true as const,
})),
...toolCalls.map(mapToolCall),
],
text,
reasoning: reasoning.map((r) => ({
Expand All @@ -809,21 +812,9 @@ function chunksToStep(
reasoningText: reasoningText || undefined,
files,
sources,
toolCalls: toolCalls.map((toolCall) => ({
type: 'tool-call' as const,
toolCallId: toolCall.toolCallId,
toolName: toolCall.toolName,
input: JSON.parse(toolCall.input),
dynamic: true as const,
})),
toolCalls: toolCalls.map(mapToolCall),
staticToolCalls: [],
dynamicToolCalls: toolCalls.map((toolCall) => ({
type: 'tool-call' as const,
toolCallId: toolCall.toolCallId,
toolName: toolCall.toolName,
input: JSON.parse(toolCall.input),
dynamic: true as const,
})),
dynamicToolCalls: toolCalls.map(mapToolCall),
toolResults: [],
staticToolResults: [],
dynamicToolResults: [],
Expand Down Expand Up @@ -864,13 +855,7 @@ function chunksToStep(
request: {
body: JSON.stringify({
prompt: conversationPrompt,
tools: toolCalls.map((toolCall) => ({
type: 'tool-call' as const,
toolCallId: toolCall.toolCallId,
toolName: toolCall.toolName,
input: JSON.parse(toolCall.input),
dynamic: true as const,
})),
tools: toolCalls.map(mapToolCall),
}),
},
response: {
Expand Down
86 changes: 86 additions & 0 deletions packages/ai/src/agent/durable-agent.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2446,6 +2446,92 @@ describe('DurableAgent', () => {
})
);
});

it('should patch repaired tool-call input back into the conversation prompt', async () => {
const repairFn: ToolCallRepairFunction<ToolSet> = vi
.fn()
.mockReturnValue({
toolCallId: 'test-call-id',
toolName: 'testTool',
input: '{"name":"repaired"}',
});

const tools: ToolSet = {
testTool: {
description: 'A test tool',
inputSchema: z.object({ name: z.string() }),
execute: async () => ({ result: 'success' }),
},
};

const mockModel = createMockModel();

const agent = new DurableAgent({
model: async () => mockModel,
tools,
});

const mockWritable = new WritableStream({
write: vi.fn(),
close: vi.fn(),
});

const mockMessages: LanguageModelV3Prompt = [
{ role: 'user', content: [{ type: 'text', text: 'test' }] },
{
role: 'assistant',
content: [
{
type: 'tool-call',
toolCallId: 'test-call-id',
toolName: 'testTool',
input: 'invalid json',
},
],
},
];

const { streamTextIterator } = await import('./stream-text-iterator.js');
const mockIterator = {
next: vi
.fn()
.mockResolvedValueOnce({
done: false,
value: {
toolCalls: [
{
toolCallId: 'test-call-id',
toolName: 'testTool',
input: 'invalid json',
} as LanguageModelV3ToolCall,
],
messages: mockMessages,
},
})
.mockResolvedValueOnce({ done: true, value: [] }),
};
vi.mocked(streamTextIterator).mockReturnValue(
mockIterator as unknown as MockIterator
);

await agent.stream({
messages: [{ role: 'user', content: 'test' }],
writable: mockWritable,
experimental_repairToolCall: repairFn,
});

expect(mockMessages[1]).toMatchObject({
role: 'assistant',
content: [
{
type: 'tool-call',
toolCallId: 'test-call-id',
toolName: 'testTool',
input: { name: 'repaired' },
},
],
});
});
});

describe('includeRawChunks', () => {
Expand Down
53 changes: 40 additions & 13 deletions packages/ai/src/agent/durable-agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import {
} from 'ai';
import { convertToLanguageModelPrompt, standardizePrompt } from 'ai/internal';
import { getErrorMessage } from '../get-error-message.js';
import { safeParseToolCallInput } from './safe-parse-tool-call-input.js';
import { streamTextIterator } from './stream-text-iterator.js';
import { recordSpan, runInContext } from './telemetry.js';
import type { CompatibleLanguageModel } from './types.js';
Expand Down Expand Up @@ -1133,15 +1134,15 @@ export class DurableAgent<TBaseTools extends ToolSet = ToolSet> {
type: 'tool-call' as const,
toolCallId: tc.toolCallId,
toolName: tc.toolName,
input: safeParseInput(tc.input),
input: safeParseToolCallInput(tc.input),
}));

// Build toolResults only for tools that were executed
const allToolResults: ToolResult[] = resolvedResults.map((r) => ({
type: 'tool-result' as const,
toolCallId: r.toolCallId,
toolName: r.toolName,
input: safeParseInput(
input: safeParseToolCallInput(
toolCalls.find((tc) => tc.toolCallId === r.toolCallId)?.input
),
output: 'value' in r.output ? r.output.value : undefined,
Expand Down Expand Up @@ -1244,13 +1245,13 @@ export class DurableAgent<TBaseTools extends ToolSet = ToolSet> {
type: 'tool-call' as const,
toolCallId: tc.toolCallId,
toolName: tc.toolName,
input: safeParseInput(tc.input),
input: safeParseToolCallInput(tc.input),
}));
lastStepToolResults = toolResults.map((r) => ({
type: 'tool-result' as const,
toolCallId: r.toolCallId,
toolName: r.toolName,
input: safeParseInput(
input: safeParseToolCallInput(
toolCalls.find((tc) => tc.toolCallId === r.toolCallId)?.input
),
output: 'value' in r.output ? r.output.value : undefined,
Expand Down Expand Up @@ -1478,10 +1479,6 @@ async function convertChunksToUIMessages(
return messages;
}

/**
* Safely parse tool call input JSON. Returns the parsed value or the raw string
* if parsing fails (e.g., for tool calls that were repaired).
*/
/**
* Valid `type` values for LanguageModelV3ToolResultOutput.
* When a tool returns an object whose `type` matches one of these,
Expand All @@ -1503,11 +1500,35 @@ function isToolResultOutput(
return TOOL_RESULT_OUTPUT_TYPES.has((result as { type?: string }).type ?? '');
}

function safeParseInput(input: string | undefined): unknown {
try {
return JSON.parse(input || '{}');
} catch {
return input;
function patchToolCallInMessages(
messages: LanguageModelV3Prompt,
toolCall: LanguageModelV3ToolCall
): void {
const repairedInput = safeParseToolCallInput(toolCall.input);

for (let i = messages.length - 1; i >= 0; i--) {
const message = messages[i];

if (message.role !== 'assistant' || !Array.isArray(message.content)) {
continue;
}

const toolCallPart = message.content.find(
(
part
): part is {
type: 'tool-call';
toolCallId: string;
toolName: string;
input: unknown;
} => part.type === 'tool-call' && part.toolCallId === toolCall.toolCallId
);

if (toolCallPart) {
toolCallPart.toolName = toolCall.toolName;
toolCallPart.input = repairedInput;
return;
}
}
}

Expand Down Expand Up @@ -1589,6 +1610,9 @@ async function executeTool(
messages,
});
if (repairedToolCall) {
toolCall.toolName = repairedToolCall.toolName;
toolCall.input = repairedToolCall.input;
patchToolCallInMessages(messages, repairedToolCall);
// Retry with repaired tool call
return executeTool(
repairedToolCall,
Expand All @@ -1614,6 +1638,9 @@ async function executeTool(
messages,
});
if (repairedToolCall) {
toolCall.toolName = repairedToolCall.toolName;
toolCall.input = repairedToolCall.input;
patchToolCallInMessages(messages, repairedToolCall);
// Retry with repaired tool call
return executeTool(
repairedToolCall,
Expand Down
Loading
Loading