@@ -15,8 +15,7 @@ public final class AsrManager {
1515 internal let config : ASRConfig
1616 private let audioConverter : AudioConverter = AudioConverter ( )
1717
18- internal var melspectrogramModel : MLModel ?
19- internal var encoderModel : MLModel ?
18+ internal var melEncoderModel : MLModel ?
2019 internal var decoderModel : MLModel ?
2120 internal var jointModel : MLModel ?
2221
@@ -43,36 +42,30 @@ public final class AsrManager {
4342 AsrModels . optimizedPredictionOptions ( )
4443 } ( )
4544
46- // Persistent feature providers for zero-copy model chaining
47- private var zeroCopyProviders : [ String : ZeroCopyFeatureProvider ] = [ : ]
48-
4945 public init ( config: ASRConfig = . default) {
5046 self . config = config
5147
52- // Initialize decoder states with fallback
53- do {
54- self . microphoneDecoderState = try TdtDecoderState ( )
55- self . systemDecoderState = try TdtDecoderState ( )
56- } catch {
57- logger. warning ( " Failed to create ANE-aligned decoder states, using standard allocation " )
58- // This should rarely happen, but if it does, we'll create them during first use
59- self . microphoneDecoderState = TdtDecoderState ( fallback: true )
60- self . systemDecoderState = TdtDecoderState ( fallback: true )
61- }
48+ self . microphoneDecoderState = TdtDecoderState . make ( )
49+ self . systemDecoderState = TdtDecoderState . make ( )
6250
6351 // Pre-warm caches if possible
6452 Task {
6553 await sharedMLArrayCache. prewarm ( shapes: [
66- ( [ 1 , 240000 ] , . float32) ,
67- ( [ 1 ] , . int32) ,
68- ( [ 2 , 1 , 640 ] , . float32) ,
54+ ( [ NSNumber ( value: 1 ) , NSNumber ( value: 240_000 ) ] , . float32) ,
55+ ( [ NSNumber ( value: 1 ) ] , . int32) ,
56+ (
57+ [
58+ NSNumber ( value: 2 ) ,
59+ NSNumber ( value: 1 ) ,
60+ NSNumber ( value: ASRConstants . decoderHiddenSize) ,
61+ ] , . float32
62+ ) ,
6963 ] )
7064 }
7165 }
7266
7367 public var isAvailable : Bool {
74- return melspectrogramModel != nil && encoderModel != nil && decoderModel != nil
75- && jointModel != nil
68+ return melEncoderModel != nil && decoderModel != nil && jointModel != nil
7669 }
7770
7871 /// Initialize ASR Manager with pre-loaded models
@@ -81,8 +74,7 @@ public final class AsrManager {
8174 logger. info ( " Initializing AsrManager with provided models " )
8275
8376 self . asrModels = models
84- self . melspectrogramModel = models. melspectrogram
85- self . encoderModel = models. encoder
77+ self . melEncoderModel = models. melEncoder
8678 self . decoderModel = models. decoder
8779 self . jointModel = models. joint
8880 self . vocabulary = models. vocabulary
@@ -112,7 +104,7 @@ public final class AsrManager {
112104 return array
113105 }
114106
115- func prepareMelSpectrogramInput (
107+ func prepareMelEncoderInput (
116108 _ audioSamples: [ Float ] , actualLength: Int ? = nil
117109 ) async throws
118110 -> MLFeatureProvider
@@ -141,37 +133,6 @@ public final class AsrManager {
141133 ] )
142134 }
143135
144- func prepareEncoderInput( _ melspectrogramOutput: MLFeatureProvider ) throws -> MLFeatureProvider {
145- // Zero-copy: chain mel-spectrogram outputs directly to encoder inputs
146- if let provider = ZeroCopyFeatureProvider . chain (
147- from: melspectrogramOutput,
148- outputName: " melspectrogram " ,
149- to: " audio_signal "
150- ) {
151- // Also need to chain the length
152- if let melLength = melspectrogramOutput. featureValue ( for: " melspectrogram_length " ) {
153- let features = [
154- " audio_signal " : provider. featureValue ( for: " audio_signal " ) !,
155- " length " : melLength,
156- ]
157- return ZeroCopyFeatureProvider ( features: features)
158- }
159- }
160-
161- // Fallback to copying if zero-copy fails
162- let melspectrogram = try extractFeatureValue (
163- from: melspectrogramOutput, key: " melspectrogram " ,
164- errorMessage: " Invalid mel-spectrogram output " )
165- let melspectrogramLength = try extractFeatureValue (
166- from: melspectrogramOutput, key: " melspectrogram_length " ,
167- errorMessage: " Invalid mel-spectrogram length output " )
168-
169- return try createFeatureProvider ( features: [
170- ( " audio_signal " , melspectrogram) ,
171- ( " length " , melspectrogramLength) ,
172- ] )
173- }
174-
175136 private func prepareDecoderInput(
176137 hiddenState: MLMultiArray ,
177138 cellState: MLMultiArray
@@ -181,7 +142,7 @@ public final class AsrManager {
181142
182143 return try createFeatureProvider ( features: [
183144 ( " targets " , targetArray) ,
184- ( " target_lengths " , targetLengthArray) ,
145+ ( " target_length " , targetLengthArray) ,
185146 ( " h_in " , hiddenState) ,
186147 ( " c_in " , cellState) ,
187148 ] )
@@ -225,21 +186,18 @@ public final class AsrManager {
225186 }
226187
227188 private func loadAllModels(
228- melspectrogramPath: URL ,
229- encoderPath: URL ,
189+ melEncoderPath: URL ,
230190 decoderPath: URL ,
231191 jointPath: URL ,
232192 configuration: MLModelConfiguration
233- ) async throws -> ( melspectrogram: MLModel , encoder: MLModel , decoder: MLModel , joint: MLModel ) {
234- async let melspectrogram = loadModel (
235- path: melspectrogramPath, name: " mel-spectrogram " , configuration: configuration)
236- async let encoder = loadModel (
237- path: encoderPath, name: " encoder " , configuration: configuration)
193+ ) async throws -> ( melEncoder: MLModel , decoder: MLModel , joint: MLModel ) {
194+ async let melEncoder = loadModel (
195+ path: melEncoderPath, name: " mel-encoder " , configuration: configuration)
238196 async let decoder = loadModel (
239197 path: decoderPath, name: " decoder " , configuration: configuration)
240198 async let joint = loadModel ( path: jointPath, name: " joint " , configuration: configuration)
241199
242- return try await ( melspectrogram , encoder , decoder, joint)
200+ return try await ( melEncoder , decoder, joint)
243201 }
244202
245203 private static func getDefaultModelsDirectory( ) -> URL {
@@ -255,18 +213,17 @@ public final class AsrManager {
255213 }
256214
257215 public func resetState( ) {
258- microphoneDecoderState = TdtDecoderState ( fallback : true )
259- systemDecoderState = TdtDecoderState ( fallback : true )
216+ microphoneDecoderState = TdtDecoderState . make ( )
217+ systemDecoderState = TdtDecoderState . make ( )
260218 }
261219
262220 public func cleanup( ) {
263- melspectrogramModel = nil
264- encoderModel = nil
221+ melEncoderModel = nil
265222 decoderModel = nil
266223 jointModel = nil
267- // Reset decoder states - use fallback initializer that won't throw
268- microphoneDecoderState = TdtDecoderState ( fallback : true )
269- systemDecoderState = TdtDecoderState ( fallback : true )
224+ // Reset decoder states using fresh allocations for deterministic behavior
225+ microphoneDecoderState = TdtDecoderState . make ( )
226+ systemDecoderState = TdtDecoderState . make ( )
270227 logger. info ( " AsrManager resources cleaned up " )
271228 }
272229
0 commit comments