Skip to content

Commit 531eb67

Browse files
noorbhatiamattt
andauthored
Fix incorrect formatting in OpenAI responses API for tool (#125)
* Fix format mismatch in OpenAI responses API * Rewrite expressions to improve clarity * Add unit tests for OpenAI tool calling formatting --------- Co-authored-by: Mattt Zmuda <mattt@me.com>
1 parent d589b79 commit 531eb67

File tree

2 files changed

+175
-49
lines changed

2 files changed

+175
-49
lines changed

Sources/AnyLanguageModel/Models/OpenAILanguageModel.swift

Lines changed: 42 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1050,59 +1050,60 @@ private enum Responses {
10501050
outputs.append(object)
10511051

10521052
case .tool(let id):
1053-
let toolMessage = msg
1054-
// Wrap user content into a single top-level message as required by Responses API
1055-
var contentBlocks: [JSONValue]
1056-
switch toolMessage.content {
1053+
let outputValue: JSONValue
1054+
switch msg.content {
10571055
case .text(let text):
1058-
contentBlocks = [
1059-
.object(["type": .string("input_text"), "text": .string(text)])
1060-
]
1056+
outputValue = .string(text)
10611057
case .blocks(let blocks):
1062-
contentBlocks = blocks.map { block in
1063-
switch block {
1064-
case .text(let text):
1065-
return .object(["type": .string("input_text"), "text": .string(text)])
1066-
case .imageURL(let url):
1067-
return .object([
1068-
"type": .string("input_image"),
1069-
"image_url": .object(["url": .string(url)]),
1070-
])
1058+
outputValue = .array(
1059+
blocks.map { block in
1060+
switch block {
1061+
case .text(let text):
1062+
return .object(["type": .string("input_text"), "text": .string(text)])
1063+
case .imageURL(let url):
1064+
return .object([
1065+
"type": .string("input_image"),
1066+
"image_url": .string(url),
1067+
])
1068+
}
10711069
}
1072-
}
1073-
}
1074-
let outputString: String
1075-
if contentBlocks.count > 1 {
1076-
let encoder = JSONEncoder()
1077-
if let data = try? encoder.encode(JSONValue.array(contentBlocks)),
1078-
let str = String(data: data, encoding: .utf8)
1079-
{
1080-
outputString = str
1081-
} else {
1082-
outputString = "[]"
1083-
}
1084-
} else if let block = contentBlocks.first {
1085-
let encoder = JSONEncoder()
1086-
if let data = try? encoder.encode(block),
1087-
let str = String(data: data, encoding: .utf8)
1088-
{
1089-
outputString = str
1090-
} else {
1091-
outputString = "{}"
1092-
}
1093-
} else {
1094-
outputString = "{}"
1070+
)
10951071
}
10961072
outputs.append(
10971073
.object([
10981074
"type": .string("function_call_output"),
10991075
"call_id": .string(id),
1100-
"output": .string(outputString),
1076+
"output": outputValue,
11011077
])
11021078
)
11031079

11041080
case .raw(rawContent: let rawContent):
1105-
outputs.append(rawContent)
1081+
// Convert Chat Completions assistant+tool_calls to Responses API function_call items
1082+
if case .object(let assistantMessageObject) = rawContent,
1083+
case .string(let messageRole) = assistantMessageObject["role"],
1084+
messageRole == "assistant",
1085+
case .array(let assistantToolCalls) = assistantMessageObject["tool_calls"]
1086+
{
1087+
for assistantToolCall in assistantToolCalls {
1088+
if case .object(let toolCallObject) = assistantToolCall,
1089+
case .string(let toolCallID) = toolCallObject["id"],
1090+
case .object(let functionCallObject) = toolCallObject["function"],
1091+
case .string(let functionName) = functionCallObject["name"],
1092+
case .string(let functionArguments) = functionCallObject["arguments"]
1093+
{
1094+
outputs.append(
1095+
.object([
1096+
"type": .string("function_call"),
1097+
"call_id": .string(toolCallID),
1098+
"name": .string(functionName),
1099+
"arguments": .string(functionArguments),
1100+
])
1101+
)
1102+
}
1103+
}
1104+
} else {
1105+
outputs.append(rawContent)
1106+
}
11061107

11071108
case .system:
11081109
let systemMessage = msg

Tests/AnyLanguageModelTests/OpenAILanguageModelTests.swift

Lines changed: 133 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import Foundation
2+
import JSONSchema
23
import Testing
34

45
@testable import AnyLanguageModel
@@ -14,7 +15,6 @@ struct OpenAILanguageModelTests {
1415
}
1516

1617
@Test func apiVariantParameterization() throws {
17-
// Test that both API variants can be created and have correct properties
1818
for apiVariant in [OpenAILanguageModel.APIVariant.chatCompletions, .responses] {
1919
let model = OpenAILanguageModel(apiKey: "test-key", model: "test-model", apiVariant: apiVariant)
2020
#expect(model.apiVariant == apiVariant)
@@ -97,7 +97,6 @@ struct OpenAILanguageModelTests {
9797
maximumResponseTokens: 50
9898
)
9999

100-
// Set custom options (extraBody will be merged into the request)
101100
options[custom: OpenAILanguageModel.self] = .init(
102101
extraBody: ["user": .string("test-user-id")]
103102
)
@@ -138,19 +137,77 @@ struct OpenAILanguageModelTests {
138137
}
139138

140139
@Test func withTools() async throws {
141-
let weatherTool = WeatherTool()
140+
let weatherTool = spy(on: WeatherTool())
142141
let session = LanguageModelSession(model: model, tools: [weatherTool])
143142

144-
let response = try await session.respond(to: "How's the weather in San Francisco?")
143+
var options = GenerationOptions()
144+
options[custom: OpenAILanguageModel.self] = .init(
145+
maxToolCalls: 1
146+
)
147+
148+
let response = try await withOpenAIRateLimitRetry {
149+
try await session.respond(
150+
to: "Call getWeather for San Francisco exactly once, then summarize in one sentence.",
151+
options: options
152+
)
153+
}
154+
155+
#expect(!response.content.isEmpty)
156+
let calls = await weatherTool.calls
157+
#expect(!calls.isEmpty)
158+
if let firstCall = calls.first {
159+
#expect(firstCall.arguments.city.localizedCaseInsensitiveContains("san"))
160+
}
145161

162+
var foundToolCall = false
146163
var foundToolOutput = false
147-
for case let .toolOutput(toolOutput) in response.transcriptEntries {
148-
#expect(toolOutput.toolName == "getWeather")
149-
foundToolOutput = true
164+
for entry in response.transcriptEntries {
165+
switch entry {
166+
case .toolCalls(let toolCalls):
167+
#expect(!toolCalls.isEmpty)
168+
if let firstToolCall = toolCalls.first {
169+
#expect(firstToolCall.toolName == "getWeather")
170+
}
171+
foundToolCall = true
172+
case .toolOutput(let toolOutput):
173+
#expect(toolOutput.toolName == "getWeather")
174+
foundToolOutput = true
175+
default:
176+
break
177+
}
150178
}
179+
#expect(foundToolCall)
151180
#expect(foundToolOutput)
152181
}
153182

183+
@Test func withToolsConversationContinuesAcrossTurns() async throws {
184+
let weatherTool = spy(on: WeatherTool())
185+
let session = LanguageModelSession(model: model, tools: [weatherTool])
186+
187+
var options = GenerationOptions()
188+
options[custom: OpenAILanguageModel.self] = .init(
189+
maxToolCalls: 1
190+
)
191+
192+
_ = try await withOpenAIRateLimitRetry {
193+
try await session.respond(
194+
to: "Call getWeather for San Francisco exactly once, then reply with only: done",
195+
options: options
196+
)
197+
}
198+
199+
let secondResponse = try await withOpenAIRateLimitRetry {
200+
try await session.respond(
201+
to: "Which city did the tool call use? Reply with city only."
202+
)
203+
}
204+
#expect(!secondResponse.content.isEmpty)
205+
#expect(secondResponse.content.localizedCaseInsensitiveContains("san"))
206+
207+
let calls = await weatherTool.calls
208+
#expect(calls.count >= 1)
209+
}
210+
154211
@Suite("Structured Output")
155212
struct StructuredOutputTests {
156213
@Generable
@@ -316,7 +373,6 @@ struct OpenAILanguageModelTests {
316373
maximumResponseTokens: 50
317374
)
318375

319-
// Set custom options (extraBody will be merged into the request)
320376
options[custom: OpenAILanguageModel.self] = .init(
321377
extraBody: ["user": "test-user-id"]
322378
)
@@ -459,4 +515,73 @@ struct OpenAILanguageModelTests {
459515
}
460516
}
461517
}
518+
519+
@Suite("OpenAILanguageModel Responses Request Body")
520+
struct ResponsesRequestBodyTests {
521+
private let model = "test-model"
522+
523+
private func inputArray(from body: JSONValue) -> [JSONValue]? {
524+
guard case let .object(obj) = body else { return nil }
525+
guard case let .array(input)? = obj["input"] else { return nil }
526+
return input
527+
}
528+
529+
private func stringValue(_ value: JSONValue?) -> String? {
530+
guard case let .string(text)? = value else { return nil }
531+
return text
532+
}
533+
534+
private func firstObject(withType type: String, in input: [JSONValue]) -> [String: JSONValue]? {
535+
for value in input {
536+
guard case let .object(obj) = value else { continue }
537+
guard case let .string(foundType)? = obj["type"], foundType == type else { continue }
538+
return obj
539+
}
540+
return nil
541+
}
542+
543+
private func containsKey(_ value: JSONValue, key: String) -> Bool {
544+
guard case let .object(obj) = value else { return false }
545+
return obj[key] != nil
546+
}
547+
548+
private func makePrompt(_ text: String = "Continue.") -> Transcript.Prompt {
549+
Transcript.Prompt(segments: [.text(.init(content: text))])
550+
}
551+
552+
private func makeTranscriptWithToolCalls() throws -> Transcript {
553+
let arguments = try GeneratedContent(json: #"{"city":"Paris"}"#)
554+
let call = Transcript.ToolCall(id: "call-1", toolName: "getWeather", arguments: arguments)
555+
let toolCalls = Transcript.ToolCalls([call])
556+
return Transcript(entries: [
557+
.toolCalls(toolCalls),
558+
.prompt(makePrompt()),
559+
])
560+
}
561+
}
562+
}
563+
564+
private func withOpenAIRateLimitRetry<T>(
565+
maxAttempts: Int = 4,
566+
operation: @escaping () async throws -> T
567+
) async throws -> T {
568+
var attempt = 1
569+
while true {
570+
do {
571+
return try await operation()
572+
} catch let error as URLSessionError {
573+
if case .httpError(_, let detail) = error,
574+
detail.contains("rate_limit_exceeded"),
575+
attempt < maxAttempts
576+
{
577+
let delaySeconds = UInt64(attempt)
578+
try await Task.sleep(nanoseconds: delaySeconds * 1_000_000_000)
579+
attempt += 1
580+
continue
581+
}
582+
throw error
583+
} catch {
584+
throw error
585+
}
586+
}
462587
}

0 commit comments

Comments
 (0)