Skip to content

Commit d589b79

Browse files
authored
Make follow-up changes to Locked (#131)
* Move Locked to new Shared subdirectory * Rename Locked.access to Locked.withLock * Add concise documentation comments to Locked * Move StructuredGeneration to Shared directory * Adopt Locked in MLXLanguageModel * Adopt Locked in StructuredGeneration * Don't shadow inFlight in closure * Replace Important with Note * Don't shadow state in closure
1 parent 99679b3 commit d589b79

File tree

7 files changed

+91
-95
lines changed

7 files changed

+91
-95
lines changed

Sources/AnyLanguageModel/LanguageModelSession.swift

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@ import Observation
55
public final class LanguageModelSession: @unchecked Sendable {
66
public var isResponding: Bool {
77
access(keyPath: \.isResponding)
8-
return state.access { $0.isResponding }
8+
return state.withLock { $0.isResponding }
99
}
1010

1111
public var transcript: Transcript {
1212
access(keyPath: \.transcript)
13-
return state.access { $0.transcript }
13+
return state.withLock { $0.transcript }
1414
}
1515

1616
@ObservationIgnored private let state: Locked<State>
@@ -103,13 +103,13 @@ public final class LanguageModelSession: @unchecked Sendable {
103103

104104
nonisolated private func beginResponding() {
105105
withMutation(keyPath: \.isResponding) {
106-
state.access { $0.beginResponding() }
106+
state.withLock { $0.beginResponding() }
107107
}
108108
}
109109

110110
nonisolated private func endResponding() {
111111
withMutation(keyPath: \.isResponding) {
112-
state.access { $0.endResponding() }
112+
state.withLock { $0.endResponding() }
113113
}
114114
}
115115

@@ -159,7 +159,7 @@ public final class LanguageModelSession: @unchecked Sendable {
159159
)
160160
)
161161
session.withMutation(keyPath: \.transcript) {
162-
session.state.access { $0.transcript.append(responseEntry) }
162+
session.state.withLock { $0.transcript.append(responseEntry) }
163163
}
164164
}
165165
} catch {
@@ -209,7 +209,7 @@ public final class LanguageModelSession: @unchecked Sendable {
209209
)
210210
)
211211
withMutation(keyPath: \.transcript) {
212-
state.access { $0.transcript.append(promptEntry) }
212+
state.withLock { $0.transcript.append(promptEntry) }
213213
}
214214

215215
let response = try await model.respond(
@@ -237,9 +237,9 @@ public final class LanguageModelSession: @unchecked Sendable {
237237

238238
// Add tool entries and response to transcript
239239
withMutation(keyPath: \.transcript) {
240-
state.access { state in
241-
state.transcript.append(contentsOf: response.transcriptEntries)
242-
state.transcript.append(responseEntry)
240+
state.withLock { lockedState in
241+
lockedState.transcript.append(contentsOf: response.transcriptEntries)
242+
lockedState.transcript.append(responseEntry)
243243
}
244244
}
245245

@@ -262,7 +262,7 @@ public final class LanguageModelSession: @unchecked Sendable {
262262
)
263263
)
264264
withMutation(keyPath: \.transcript) {
265-
state.access { $0.transcript.append(promptEntry) }
265+
state.withLock { $0.transcript.append(promptEntry) }
266266
}
267267

268268
return wrapStream(
@@ -558,7 +558,7 @@ extension LanguageModelSession {
558558
)
559559
)
560560
withMutation(keyPath: \.transcript) {
561-
state.access { $0.transcript.append(promptEntry) }
561+
state.withLock { $0.transcript.append(promptEntry) }
562562
}
563563

564564
// Extract text content for the Prompt parameter
@@ -589,9 +589,9 @@ extension LanguageModelSession {
589589

590590
// Add tool entries and response to transcript
591591
withMutation(keyPath: \.transcript) {
592-
state.access { state in
593-
state.transcript.append(contentsOf: response.transcriptEntries)
594-
state.transcript.append(responseEntry)
592+
state.withLock { lockedState in
593+
lockedState.transcript.append(contentsOf: response.transcriptEntries)
594+
lockedState.transcript.append(responseEntry)
595595
}
596596
}
597597

@@ -664,7 +664,7 @@ extension LanguageModelSession {
664664
)
665665
)
666666
withMutation(keyPath: \.transcript) {
667-
state.access { $0.transcript.append(promptEntry) }
667+
state.withLock { $0.transcript.append(promptEntry) }
668668
}
669669

670670
// Extract text content for the Prompt parameter

Sources/AnyLanguageModel/Locked.swift

Lines changed: 0 additions & 16 deletions
This file was deleted.

Sources/AnyLanguageModel/Models/MLXLanguageModel.swift

Lines changed: 14 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,7 @@ import Foundation
2727
/// Coordinates a bounded in-memory cache with structured, coalesced loading.
2828
private final class ModelContextCache {
2929
private let cache: NSCache<NSString, CachedContext>
30-
private let lock = NSLock()
31-
private var inFlight: [String: Task<CachedContext, Error>] = [:]
30+
private let inFlight = Locked<[String: Task<CachedContext, Error>]>([:])
3231

3332
/// Creates a cache with a count-based eviction limit.
3433
init(countLimit: Int) {
@@ -90,37 +89,31 @@ import Foundation
9089
}
9190

9291
private func inFlightTask(for key: String) -> Task<CachedContext, Error>? {
93-
lock.lock()
94-
defer { lock.unlock() }
95-
return inFlight[key]
92+
inFlight.withLock { $0[key] }
9693
}
9794

9895
private func setInFlight(_ task: Task<CachedContext, Error>, for key: String) {
99-
lock.lock()
100-
inFlight[key] = task
101-
lock.unlock()
96+
inFlight.withLock { $0[key] = task }
10297
}
10398

10499
private func clearInFlight(for key: String) {
105-
lock.lock()
106-
inFlight[key] = nil
107-
lock.unlock()
100+
inFlight.withLock { $0[key] = nil }
108101
}
109102

110103
private func removeInFlight(for key: String) -> Task<CachedContext, Error>? {
111-
lock.lock()
112-
defer { lock.unlock() }
113-
let task = inFlight[key]
114-
inFlight[key] = nil
115-
return task
104+
inFlight.withLock {
105+
let task = $0[key]
106+
$0[key] = nil
107+
return task
108+
}
116109
}
117110

118111
private func removeAllInFlight() -> [Task<CachedContext, Error>] {
119-
lock.lock()
120-
defer { lock.unlock() }
121-
let tasks = Array(inFlight.values)
122-
inFlight.removeAll()
123-
return tasks
112+
inFlight.withLock {
113+
let tasks = Array($0.values)
114+
$0.removeAll()
115+
return tasks
116+
}
124117
}
125118
}
126119

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import Foundation
2+
3+
/// Protects shared mutable state behind an `NSLock`.
4+
final class Locked<State> {
5+
private let lock = NSLock()
6+
private var state: State
7+
8+
/// Creates a locked container with the given initial state.
9+
init(_ state: State) {
10+
self.state = state
11+
}
12+
13+
/// Executes `body` while holding the lock.
14+
///
15+
/// - Parameter body: A closure that reads or mutates the protected state.
16+
/// - Returns: The value returned by `body`.
17+
/// - Throws: Rethrows any error from `body`.
18+
/// - Note: Keep critical sections small and synchronous.
19+
func withLock<T>(_ body: (inout State) throws -> T) rethrows -> T {
20+
try lock.withLock { try body(&self.state) }
21+
}
22+
}
23+
24+
extension Locked: @unchecked Sendable where State: Sendable {}

Sources/AnyLanguageModel/StructuredGeneration.swift renamed to Sources/AnyLanguageModel/Shared/StructuredGeneration.swift

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -45,19 +45,14 @@ private final class StringTokenCache: @unchecked Sendable {
4545
let sampleTexts: [String]
4646
}
4747

48-
private var cache: [Key: Set<Int>] = [:]
49-
private let lock = NSLock()
48+
private let tokensByKey = Locked<[Key: Set<Int>]>([:])
5049

5150
func tokens(for key: Key) -> Set<Int>? {
52-
lock.lock()
53-
defer { lock.unlock() }
54-
return cache[key]
51+
tokensByKey.withLock { $0[key] }
5552
}
5653

5754
func store(_ tokens: Set<Int>, for key: Key) {
58-
lock.lock()
59-
cache[key] = tokens
60-
lock.unlock()
55+
tokensByKey.withLock { $0[key] = tokens }
6156
}
6257
}
6358

Tests/AnyLanguageModelTests/LockedTests.swift

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3,32 +3,32 @@ import Testing
33

44
@testable import AnyLanguageModel
55

6-
@Suite("Locked")
6+
@Suite("Locked Tests")
77
struct LockedTests {
88
@Test("Read access returns the initial value")
99
func readAccess() {
1010
let locked = Locked(42)
11-
let value = locked.access { $0 }
11+
let value = locked.withLock { $0 }
1212
#expect(value == 42)
1313
}
1414

1515
@Test("Write access mutates the state")
1616
func writeAccess() {
1717
let locked = Locked(0)
18-
locked.access { $0 = 99 }
19-
let value = locked.access { $0 }
18+
locked.withLock { $0 = 99 }
19+
let value = locked.withLock { $0 }
2020
#expect(value == 99)
2121
}
2222

2323
@Test("Access returns the value from the closure")
2424
func returnValue() {
2525
let locked = Locked("hello")
26-
let result = locked.access { state -> Int in
26+
let result = locked.withLock { state -> Int in
2727
state += " world"
2828
return state.count
2929
}
3030
#expect(result == 11)
31-
#expect(locked.access { $0 } == "hello world")
31+
#expect(locked.withLock { $0 } == "hello world")
3232
}
3333

3434
@Test("Access propagates thrown errors")
@@ -37,7 +37,7 @@ struct LockedTests {
3737

3838
let locked = Locked(0)
3939
#expect(throws: TestError.self) {
40-
try locked.access { _ in throw TestError() }
40+
try locked.withLock { _ in throw TestError() }
4141
}
4242
}
4343

@@ -50,14 +50,14 @@ struct LockedTests {
5050
}
5151

5252
let locked = Locked(State(name: "initial", count: 0, tags: []))
53-
locked.access { state in
53+
locked.withLock { state in
5454
state.name = "updated"
5555
state.count = 5
5656
state.tags.append("a")
5757
state.tags.append("b")
5858
}
5959

60-
let snapshot = locked.access { $0 }
60+
let snapshot = locked.withLock { $0 }
6161
#expect(snapshot.name == "updated")
6262
#expect(snapshot.count == 5)
6363
#expect(snapshot.tags == ["a", "b"])
@@ -70,11 +70,11 @@ struct LockedTests {
7070

7171
await withTaskGroup(of: Void.self) { group in
7272
for _ in 0 ..< iterations {
73-
group.addTask { locked.access { $0 += 1 } }
73+
group.addTask { locked.withLock { $0 += 1 } }
7474
}
7575
}
7676

77-
let finalValue = locked.access { $0 }
77+
let finalValue = locked.withLock { $0 }
7878
#expect(finalValue == iterations)
7979
}
8080

@@ -87,12 +87,12 @@ struct LockedTests {
8787
for i in 0 ..< iterations {
8888
let priority: TaskPriority = i.isMultiple(of: 2) ? .high : .background
8989
group.addTask(priority: priority) {
90-
locked.access { $0 += 1 }
90+
locked.withLock { $0 += 1 }
9191
}
9292
}
9393
}
9494

95-
let finalValue = locked.access { $0 }
95+
let finalValue = locked.withLock { $0 }
9696
#expect(finalValue == iterations)
9797
}
9898

@@ -103,11 +103,11 @@ struct LockedTests {
103103

104104
await withTaskGroup(of: Void.self) { group in
105105
for i in 0 ..< iterations {
106-
group.addTask { locked.access { $0.append(i) } }
106+
group.addTask { locked.withLock { $0.append(i) } }
107107
}
108108
}
109109

110-
let finalArray = locked.access { $0 }
110+
let finalArray = locked.withLock { $0 }
111111
#expect(finalArray.count == iterations)
112112
}
113113

@@ -119,13 +119,13 @@ struct LockedTests {
119119

120120
await withTaskGroup(of: Void.self) { group in
121121
for _ in 0 ..< iterations {
122-
group.addTask { lockedA.access { $0 += 1 } }
123-
group.addTask { lockedB.access { $0 += 1 } }
122+
group.addTask { lockedA.withLock { $0 += 1 } }
123+
group.addTask { lockedB.withLock { $0 += 1 } }
124124
}
125125
}
126126

127-
#expect(lockedA.access { $0 } == iterations)
128-
#expect(lockedB.access { $0 } == iterations)
127+
#expect(lockedA.withLock { $0 } == iterations)
128+
#expect(lockedB.withLock { $0 } == iterations)
129129
}
130130

131131
@Test("Can wrap a non-Sendable type")
@@ -136,17 +136,17 @@ struct LockedTests {
136136
}
137137

138138
let locked = Locked(Box(10))
139-
locked.access { $0.value += 5 }
140-
let result = locked.access { $0.value }
139+
locked.withLock { $0.value += 5 }
140+
let result = locked.withLock { $0.value }
141141
#expect(result == 15)
142142
}
143143

144144
@Test("Copies share the same underlying storage")
145145
func copySharesStorage() {
146146
let original = Locked(0)
147147
let copy = original
148-
original.access { $0 = 42 }
149-
let value = copy.access { $0 }
148+
original.withLock { $0 = 42 }
149+
let value = copy.withLock { $0 }
150150
#expect(value == 42)
151151
}
152152
}

0 commit comments

Comments
 (0)