Skip to content

Commit 41a745f

Browse files
committed
Add Cancel Streaming response via Task Cancellation
1 parent 66a547f commit 41a745f

File tree

4 files changed

+75
-43
lines changed

4 files changed

+75
-43
lines changed

Shared/ChatGPTAPI.swift

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ class ChatGPTAPI: @unchecked Sendable {
7777
urlRequest.httpBody = try jsonBody(text: text)
7878

7979
let (result, response) = try await urlSession.bytes(for: urlRequest)
80+
try Task.checkCancellation()
8081

8182
guard let httpResponse = response as? HTTPURLResponse else {
8283
throw "Invalid response"
@@ -85,6 +86,7 @@ class ChatGPTAPI: @unchecked Sendable {
8586
guard 200...299 ~= httpResponse.statusCode else {
8687
var errorText = ""
8788
for try await line in result.lines {
89+
try Task.checkCancellation()
8890
errorText += line
8991
}
9092

@@ -95,26 +97,21 @@ class ChatGPTAPI: @unchecked Sendable {
9597
throw "Bad Response: \(httpResponse.statusCode), \(errorText)"
9698
}
9799

98-
return AsyncThrowingStream<String, Error> { continuation in
99-
Task(priority: .userInitiated) { [weak self] in
100-
guard let self else { return }
101-
do {
102-
var responseText = ""
103-
for try await line in result.lines {
104-
if line.hasPrefix("data: "),
105-
let data = line.dropFirst(6).data(using: .utf8),
106-
let response = try? self.jsonDecoder.decode(StreamCompletionResponse.self, from: data),
107-
let text = response.choices.first?.delta.content {
108-
responseText += text
109-
continuation.yield(text)
110-
}
111-
}
112-
self.appendToHistoryList(userText: text, responseText: responseText)
113-
continuation.finish()
114-
} catch {
115-
continuation.finish(throwing: error)
100+
var responseText = ""
101+
return AsyncThrowingStream { [weak self] in
102+
guard let self else { return nil }
103+
for try await line in result.lines {
104+
try Task.checkCancellation()
105+
if line.hasPrefix("data: "),
106+
let data = line.dropFirst(6).data(using: .utf8),
107+
let response = try? self.jsonDecoder.decode(StreamCompletionResponse.self, from: data),
108+
let text = response.choices.first?.delta.content {
109+
responseText += text
110+
return text
116111
}
117112
}
113+
self.appendToHistoryList(userText: text, responseText: responseText)
114+
return nil
118115
}
119116
}
120117

@@ -123,7 +120,7 @@ class ChatGPTAPI: @unchecked Sendable {
123120
urlRequest.httpBody = try jsonBody(text: text, stream: false)
124121

125122
let (data, response) = try await urlSession.data(for: urlRequest)
126-
123+
try Task.checkCancellation()
127124
guard let httpResponse = response as? HTTPURLResponse else {
128125
throw "Invalid response"
129126
}

Shared/ContentView.swift

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,18 @@ struct ContentView: View {
7373
.disabled(vm.isInteractingWithChatGPT)
7474

7575
if vm.isInteractingWithChatGPT {
76+
#if os(iOS)
77+
Button {
78+
vm.cancelStreamingResponse()
79+
} label: {
80+
Image(systemName: "stop.circle.fill")
81+
.font(.system(size: 30))
82+
.symbolRenderingMode(.multicolor)
83+
.foregroundColor(.red)
84+
}
85+
#else
7686
DotLoadingView().frame(width: 60, height: 30)
87+
#endif
7788
} else {
7889
Button {
7990
Task { @MainActor in

Shared/MessageRow.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ struct MessageRow: Identifiable {
3333
var isInteractingWithChatGPT: Bool
3434

3535
let sendImage: String
36-
let send: MessageRowType
36+
var send: MessageRowType
3737
var sendText: String {
3838
send.text
3939
}

Shared/ViewModel.swift

Lines changed: 47 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ class ViewModel: ObservableObject {
1414
@Published var isInteractingWithChatGPT = false
1515
@Published var messages: [MessageRow] = []
1616
@Published var inputMessage: String = ""
17+
var task: Task<Void, Never>?
1718

1819
#if !os(watchOS)
1920
private var synthesizer: AVSpeechSynthesizer?
@@ -32,11 +33,15 @@ class ViewModel: ObservableObject {
3233

3334
@MainActor
3435
func sendTapped() async {
35-
let text = inputMessage
36-
inputMessage = ""
3736
#if os(iOS)
38-
await sendAttributed(text: text)
37+
self.task = Task {
38+
let text = inputMessage
39+
inputMessage = ""
40+
await sendAttributed(text: text)
41+
}
3942
#else
43+
let text = inputMessage
44+
inputMessage = ""
4045
await send(text: text)
4146
#endif
4247
}
@@ -52,47 +57,61 @@ class ViewModel: ObservableObject {
5257

5358
@MainActor
5459
func retry(message: MessageRow) async {
60+
#if os(iOS)
61+
self.task = Task {
62+
guard let index = messages.firstIndex(where: { $0.id == message.id }) else {
63+
return
64+
}
65+
self.messages.remove(at: index)
66+
await sendAttributed(text: message.sendText)
67+
}
68+
#else
5569
guard let index = messages.firstIndex(where: { $0.id == message.id }) else {
5670
return
5771
}
5872
self.messages.remove(at: index)
59-
#if os(iOS)
60-
await sendAttributed(text: message.sendText)
61-
#else
6273
await send(text: message.sendText)
6374
#endif
6475
}
6576

77+
func cancelStreamingResponse() {
78+
self.task?.cancel()
79+
self.task = nil
80+
}
81+
6682
#if os(iOS)
6783
@MainActor
6884
private func sendAttributed(text: String) async {
6985
isInteractingWithChatGPT = true
70-
71-
let parsingTask = ResponseParsingTask()
72-
let attributedSend = await parsingTask.parse(text: text)
73-
7486
var streamText = ""
87+
7588
var messageRow = MessageRow(
7689
isInteractingWithChatGPT: true,
7790
sendImage: "profile",
78-
send: .attributed(attributedSend),
91+
send: .rawText(text),
7992
responseImage: "openai",
8093
responseError: nil)
81-
82-
self.messages.append(messageRow)
83-
84-
let parserThresholdTextCount = 64
85-
var currentTextCount = 0
86-
var currentOutput: AttributedOutput?
87-
94+
8895
do {
96+
let parsingTask = ResponseParsingTask()
97+
let attributedSend = await parsingTask.parse(text: text)
98+
try Task.checkCancellation()
99+
messageRow.send = .attributed(attributedSend)
100+
101+
self.messages.append(messageRow)
102+
103+
let parserThresholdTextCount = 64
104+
var currentTextCount = 0
105+
var currentOutput: AttributedOutput?
106+
89107
let stream = try await api.sendMessageStream(text: text)
90108
for try await text in stream {
91109
streamText += text
92110
currentTextCount += text.count
93111

94112
if currentTextCount >= parserThresholdTextCount || text.contains("```") {
95113
currentOutput = await parsingTask.parse(text: streamText)
114+
try Task.checkCancellation()
96115
currentTextCount = 0
97116
}
98117

@@ -115,17 +134,22 @@ class ViewModel: ObservableObject {
115134
}
116135

117136
self.messages[self.messages.count - 1] = messageRow
137+
if let currentString = currentOutput?.string, currentString != streamText {
138+
let output = await parsingTask.parse(text: streamText)
139+
try Task.checkCancellation()
140+
messageRow.response = .attributed(output)
141+
}
118142
}
143+
} catch is CancellationError {
144+
messageRow.responseError = "The response was cancelled"
119145
} catch {
120146
messageRow.responseError = error.localizedDescription
121-
messageRow.response = .rawText(streamText)
122147
}
123148

124-
if let currentString = currentOutput?.string, currentString != streamText {
125-
let output = await parsingTask.parse(text: streamText)
126-
messageRow.response = .attributed(output)
149+
if messageRow.response == nil {
150+
messageRow.response = .rawText(streamText)
127151
}
128-
152+
129153
messageRow.isInteractingWithChatGPT = false
130154
self.messages[self.messages.count - 1] = messageRow
131155
isInteractingWithChatGPT = false

0 commit comments

Comments
 (0)