Skip to content

Commit df1685b

Browse files
committed
Add optional imageData for vision and DallE3 method
1 parent eaaec98 commit df1685b

File tree

1 file changed

+68
-6
lines changed

1 file changed

+68
-6
lines changed

Sources/ChatGPTSwift/ChatGPTAPI.swift

Lines changed: 68 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -86,9 +86,15 @@ public class ChatGPTAPI: @unchecked Sendable {
8686
public func sendMessageStream(text: String,
8787
model: ChatGPTModel = .gpt_hyphen_4o,
8888
systemText: String = ChatGPTAPI.Constants.defaultSystemText,
89-
temperature: Double = ChatGPTAPI.Constants.defaultTemperature) async throws -> AsyncMapSequence<AsyncThrowingPrefixWhileSequence<AsyncThrowingMapSequence<ServerSentEventsDeserializationSequence<ServerSentEventsLineDeserializationSequence<HTTPBody>>, ServerSentEventWithJSONData<Components.Schemas.CreateChatCompletionStreamResponse>>>, String> {
89+
temperature: Double = ChatGPTAPI.Constants.defaultTemperature,
90+
imageData: Data? = nil) async throws -> AsyncMapSequence<AsyncThrowingPrefixWhileSequence<AsyncThrowingMapSequence<ServerSentEventsDeserializationSequence<ServerSentEventsLineDeserializationSequence<HTTPBody>>, ServerSentEventWithJSONData<Components.Schemas.CreateChatCompletionStreamResponse>>>, String> {
91+
var messages = generateInternalMessages(from: text, systemText: systemText)
92+
if let imageData {
93+
messages.append(createMessage(imageData: imageData))
94+
}
95+
9096
let response = try await client.createChatCompletion(.init(headers: .init(accept: [.init(contentType: .text_event_hyphen_stream)]), body: .json(.init(
91-
messages: self.generateInternalMessages(from: text, systemText: systemText),
97+
messages: messages,
9298
model: .init(value1: nil, value2: model),
9399
stream: true))))
94100

@@ -124,10 +130,15 @@ public class ChatGPTAPI: @unchecked Sendable {
124130
public func sendMessage(text: String,
125131
model: ChatGPTModel = .gpt_hyphen_4o,
126132
systemText: String = ChatGPTAPI.Constants.defaultSystemText,
127-
temperature: Double = ChatGPTAPI.Constants.defaultTemperature) async throws -> String {
133+
temperature: Double = ChatGPTAPI.Constants.defaultTemperature,
134+
imageData: Data? = nil) async throws -> String {
135+
var messages = generateInternalMessages(from: text, systemText: systemText)
136+
if let imageData {
137+
messages.append(createMessage(imageData: imageData))
138+
}
128139

129140
let response = try await client.createChatCompletion(body: .json(.init(
130-
messages: self.generateInternalMessages(from: text, systemText: systemText),
141+
messages: messages,
131142
model: .init(value1: nil, value2: model))))
132143

133144
switch response {
@@ -146,10 +157,16 @@ public class ChatGPTAPI: @unchecked Sendable {
146157
public func callFunction(prompt: String,
147158
tools: [ChatCompletionTool],
148159
model: Components.Schemas.CreateChatCompletionRequest.modelPayload.Value2Payload = .gpt_hyphen_4,
149-
systemText: String = "Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous."
160+
systemText: String = "Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous.",
161+
imageData: Data? = nil
150162
) async throws -> ChatCompletionResponseMessage {
163+
var messages = generateInternalMessages(from: prompt, systemText: systemText)
164+
if let imageData {
165+
messages.append(createMessage(imageData: imageData))
166+
}
167+
151168
let response = try await client.createChatCompletion(.init(body: .json(.init(
152-
messages: generateInternalMessages(from: prompt, systemText: systemText),
169+
messages: messages,
153170
model: .init(value1: nil, value2: model),
154171
tools: tools,
155172
tool_choice: .none))))
@@ -233,6 +250,39 @@ public class ChatGPTAPI: @unchecked Sendable {
233250
}
234251
#endif
235252

253+
public func generateDallE3Image(prompt: String,
254+
quality: Components.Schemas.CreateImageRequest.qualityPayload = .standard,
255+
responseFormat: Components.Schemas.CreateImageRequest.response_formatPayload = .url,
256+
style: Components.Schemas.CreateImageRequest.stylePayload = .vivid
257+
258+
) async throws -> Components.Schemas.Image {
259+
260+
let response = try await client.createImage(.init(body: .json(
261+
.init(
262+
prompt: prompt,
263+
model: .init(value1: nil, value2: .dall_hyphen_e_hyphen_3),
264+
n: 1,
265+
quality: quality,
266+
response_format: responseFormat,
267+
size: ._1024x1024,
268+
style: style
269+
))))
270+
271+
switch response {
272+
case .ok(let response):
273+
switch response.body {
274+
case .json(let imageResponse) where imageResponse.data.first != nil:
275+
return imageResponse.data.first!
276+
277+
default:
278+
throw "Unknown response"
279+
}
280+
281+
case .undocumented(let statusCode, let payload):
282+
throw getError(statusCode: statusCode, model: Components.Schemas.CreateImageRequest.modelPayload.Value2Payload.dall_hyphen_e_hyphen_3.rawValue, payload: payload)
283+
}
284+
}
285+
236286
func getError(statusCode: Int, model: String?, payload: UndocumentedPayload?) -> Error {
237287
var error = "\(statusCode) - "
238288
if statusCode == 401 {
@@ -256,5 +306,17 @@ public class ChatGPTAPI: @unchecked Sendable {
256306
return error
257307
}
258308

309+
310+
func createMessage(imageData: Data) -> Components.Schemas.ChatCompletionRequestMessage {
311+
.ChatCompletionRequestUserMessage(
312+
.init(content: .case2([.ChatCompletionRequestMessageContentPartImage(
313+
.init(_type: .image_url,
314+
image_url:
315+
.init(url: "data:image/jpeg;base64,\(imageData.base64EncodedString())",
316+
detail: .auto)))]),
317+
role: .user))
318+
}
319+
320+
259321
}
260322

0 commit comments

Comments
 (0)