Commit 0f7493b

Alex-Wengg and JarbasAl authored
feat: Support Parakeet-TDT-CTC-110M hybrid model (#433)
## Summary

Adds support for NVIDIA's Parakeet-TDT-CTC-110M hybrid model with fused preprocessor+encoder architecture. Based on the work by @JarbasAl in #383.

## Key Changes

### Model Architecture

- **Fused preprocessor+encoder**: No separate Encoder.mlmodelc file
- **Smaller dimensions**: encoderHidden=512, vocabSize=1024, single LSTM layer
- **Array-format vocabulary**: vocab.json instead of dict format
- **BlankId**: 1024 (same as v2)

### Code Modifications

- **AsrModels**: Optional encoder support, fused frontend loading, array vocab handling
- **AsrManager**: Version-aware decoder state shapes, fused frontend availability checking
- **AsrTranscription**: Skip encoder step when preprocessor output is fused
- **TdtDecoderState**: Parameterized LSTM layer count
- **TdtDecoderV3**: Use config.encoderHiddenSize instead of auto-detection
- **EncoderFrameView**: Accept explicit hidden size parameter
- **TranscribeCommand**: New `--model-version tdt-ctc-110m` and `--model-dir` flags
- **ModelNames**: parakeetTdtCtc110m repo reference

### CLI Usage

```bash
swift run fluidaudiocli transcribe audio.wav --model-version tdt-ctc-110m
swift run fluidaudiocli transcribe audio.wav --model-version tdt-ctc-110m --model-dir /path/to/custom/models
```

## Testing

- [ ] iOS compatibility testing (per concerns in #383)
- [ ] Benchmark performance documentation
- [ ] Verify fused model behavior on both macOS and iOS

## Related

- Closes #383
- Model repo: [FluidInference/parakeet-tdt-ctc-110m-coreml](https://huggingface.co/FluidInference/parakeet-tdt-ctc-110m-coreml)

<img width="642" height="1389" alt="IMG_5033" src="https://github.com/user-attachments/assets/a9105cf7-552b-4573-acfb-2a089bf52820" />

Co-authored-by: miro <jarbasai@mailfence.com>
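The per-version constants listed under Model Architecture can be sketched as a small table of properties. The enum below is an illustrative stand-in, not FluidAudio's actual `AsrModelVersion` type; only the values (512 hidden size, single LSTM layer, blank ID 1024 shared with v2, 8192 for v3) come from this PR.

```swift
// Illustrative sketch of the per-version constants described above.
// NOT the real FluidAudio enum; only the numeric values mirror the PR.
enum SketchModelVersion {
    case v2, v3, tdtCtc110m

    // Fused models ship one Preprocessor.mlmodelc that already contains
    // the encoder, so no separate Encoder.mlmodelc is loaded.
    var hasFusedEncoder: Bool { self == .tdtCtc110m }

    // The 110M hybrid uses half the encoder width of the 0.6B models.
    var encoderHiddenSize: Int { self == .tdtCtc110m ? 512 : 1024 }

    // The 110M prediction network has a single LSTM layer.
    var decoderLayers: Int { self == .tdtCtc110m ? 1 : 2 }

    // v2 and the 110M hybrid share blank ID 1024; v3 uses 8192.
    var blankId: Int { self == .v3 ? 8192 : 1024 }
}
```

Code that branches on these properties (rather than on the case itself) keeps new versions cheap to add, which is the pattern the diffs below follow.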
1 parent 0346057 commit 0f7493b

File tree

17 files changed: +1005 −83 lines

Documentation/ASR/TDT-CTC-110M.md

Lines changed: 473 additions & 0 deletions
Large diffs are not rendered by default.

Documentation/Models.md

Lines changed: 2 additions & 0 deletions
```diff
@@ -10,6 +10,7 @@ A guide to each CoreML model pipeline in FluidAudio.
 |-------|-------------|---------|
 | **Parakeet TDT v2** | Batch speech-to-text, English only (0.6B params). TDT architecture. | First ASR model added. |
 | **Parakeet TDT v3** | Batch speech-to-text, 25 European languages (0.6B params). Default ASR model. | Released after v2 to add multilingual support. |
+| **Parakeet TDT-CTC-110M** | Hybrid TDT-CTC batch model (110M params). 3.01% WER on LibriSpeech test-clean. 96.5x RTFx on M2 Mac. Fused preprocessor+encoder for reduced memory footprint. iOS compatible. | Smaller, faster alternative to v3 with competitive accuracy. |
 
 TDT models process audio in chunks (~15s with overlap) as batch operations. Fast enough for dictation-style workflows. Not suitable for word-by-word live captions.
 
@@ -63,6 +64,7 @@ Models we converted and tested but haven't shipped yet — either still in devel
 |-------|-----------------|
 | Parakeet TDT v3 | [FluidInference/parakeet-tdt-0.6b-v3-coreml](https://huggingface.co/FluidInference/parakeet-tdt-0.6b-v3-coreml) |
 | Parakeet TDT v2 | [FluidInference/parakeet-tdt-0.6b-v2-coreml](https://huggingface.co/FluidInference/parakeet-tdt-0.6b-v2-coreml) |
+| Parakeet TDT-CTC-110M | [FluidInference/parakeet-tdt-ctc-110m-coreml](https://huggingface.co/FluidInference/parakeet-tdt-ctc-110m-coreml) |
 | Parakeet CTC 110M | [FluidInference/parakeet-ctc-110m-coreml](https://huggingface.co/FluidInference/parakeet-ctc-110m-coreml) |
 | Parakeet CTC 0.6B | [FluidInference/parakeet-ctc-0.6b-coreml](https://huggingface.co/FluidInference/parakeet-ctc-0.6b-coreml) |
 | Parakeet EOU | [FluidInference/parakeet-realtime-eou-120m-coreml](https://huggingface.co/FluidInference/parakeet-realtime-eou-120m-coreml) |
```

Sources/FluidAudio/ASR/AsrManager.swift

Lines changed: 45 additions & 13 deletions
```diff
@@ -20,10 +20,16 @@ public actor AsrManager {
     internal var jointModel: MLModel?
 
     /// The AsrModels instance if initialized with models
-    private var asrModels: AsrModels?
+    internal var asrModels: AsrModels?
 
     internal let progressEmitter = ProgressEmitter()
 
+    /// Get the number of decoder layers for the current model.
+    /// Returns 2 if models not loaded (v2/v3 default, tdtCtc110m uses 1).
+    internal func getDecoderLayers() -> Int {
+        return asrModels?.version.decoderLayers ?? 2
+    }
+
     /// Token duration optimization model
 
     /// Cached vocabulary loaded once during initialization
@@ -88,14 +94,16 @@ public actor AsrManager {
     }
 
     public var isAvailable: Bool {
-        let baseModelsReady = encoderModel != nil && decoderModel != nil && jointModel != nil
-        guard baseModelsReady else { return false }
+        let decoderReady = decoderModel != nil && jointModel != nil
+        guard decoderReady else { return false }
 
         if asrModels?.usesSplitFrontend == true {
+            // Split frontend: need both preprocessor and encoder
+            return preprocessorModel != nil && encoderModel != nil
+        } else {
+            // Fused frontend: preprocessor contains encoder
             return preprocessorModel != nil
         }
-
-        return true
     }
 
     /// Initialize ASR Manager with pre-loaded models
@@ -110,7 +118,10 @@
         self.jointModel = models.joint
         self.vocabulary = models.vocabulary
 
-        logger.info("Token duration optimization model loaded successfully")
+        // Recreate decoder states with the correct layer count for this model version
+        let layers = models.version.decoderLayers
+        self.microphoneDecoderState = TdtDecoderState.make(decoderLayers: layers)
+        self.systemDecoderState = TdtDecoderState.make(decoderLayers: layers)
 
         logger.info("AsrManager initialized successfully with provided models")
     }
@@ -293,19 +304,24 @@
     }
 
     public func resetState() {
-        microphoneDecoderState = TdtDecoderState.make()
-        systemDecoderState = TdtDecoderState.make()
+        // Use model's decoder layer count, or 2 if models not loaded (v2/v3 default)
+        let layers = asrModels?.version.decoderLayers ?? 2
+        microphoneDecoderState = TdtDecoderState.make(decoderLayers: layers)
+        systemDecoderState = TdtDecoderState.make(decoderLayers: layers)
         Task { await sharedMLArrayCache.clear() }
     }
 
     public func cleanup() {
+        // Capture layer count before releasing models, fallback to 2 (v2/v3 default)
+        let layers = asrModels?.version.decoderLayers ?? 2
+        asrModels = nil
         preprocessorModel = nil
        encoderModel = nil
         decoderModel = nil
         jointModel = nil
         // Reset decoder states using fresh allocations for deterministic behavior
-        microphoneDecoderState = TdtDecoderState.make()
-        systemDecoderState = TdtDecoderState.make()
+        microphoneDecoderState = TdtDecoderState.make(decoderLayers: layers)
+        systemDecoderState = TdtDecoderState.make(decoderLayers: layers)
         // Release vocabulary boosting resources
         disableVocabularyBoosting()
         Task { await sharedMLArrayCache.clear() }
@@ -326,9 +342,25 @@
         guard let models = asrModels, let decoder_ = decoderModel, let joint = jointModel else {
             throw ASRError.notInitialized
         }
+
+        // Adapt config's encoderHiddenSize to match the loaded model version
+        // (e.g. default config uses 1024 but tdtCtc110m needs 512)
+        let adaptedConfig: ASRConfig
+        if config.encoderHiddenSize != models.version.encoderHiddenSize {
+            adaptedConfig = ASRConfig(
+                sampleRate: config.sampleRate,
+                tdtConfig: config.tdtConfig,
+                encoderHiddenSize: models.version.encoderHiddenSize,
+                streamingEnabled: config.streamingEnabled,
+                streamingThreshold: config.streamingThreshold
+            )
+        } else {
+            adaptedConfig = config
+        }
+
         switch models.version {
-        case .v2:
-            let decoder = TdtDecoderV2(config: config)
+        case .v2, .tdtCtc110m:
+            let decoder = TdtDecoderV2(config: adaptedConfig)
             return try await decoder.decodeWithTimings(
                 encoderOutput: encoderOutput,
                 encoderSequenceLength: encoderSequenceLength,
@@ -341,7 +373,7 @@
                 globalFrameOffset: globalFrameOffset
             )
         case .v3:
-            let decoder = TdtDecoderV3(config: config)
+            let decoder = TdtDecoderV3(config: adaptedConfig)
             return try await decoder.decodeWithTimings(
                 encoderOutput: encoderOutput,
                 encoderSequenceLength: encoderSequenceLength,
```
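The `adaptedConfig` logic in the AsrManager diff above rebuilds the config only when its `encoderHiddenSize` disagrees with the loaded model version. That decision can be exercised in isolation; the struct below is a hypothetical stand-in, not the real `ASRConfig`:

```swift
// Minimal stand-in for ASRConfig; only encoderHiddenSize matters here.
// The type and function names are illustrative, not FluidAudio API.
struct SketchConfig {
    var sampleRate: Int = 16_000
    var encoderHiddenSize: Int = 1024
}

// Mirror of the adaptation logic: rebuild the config only when the
// hidden size disagrees with the loaded model version's dimension.
func adapt(_ config: SketchConfig, toHiddenSize size: Int) -> SketchConfig {
    guard config.encoderHiddenSize != size else { return config }
    var adapted = config
    adapted.encoderHiddenSize = size  // all other fields carried over unchanged
    return adapted
}
```

This keeps a default config (built for the 1024-wide 0.6B encoders) usable with the 512-wide 110M model without callers having to know the dimension.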

Sources/FluidAudio/ASR/AsrModels.swift

Lines changed: 97 additions & 24 deletions
```diff
@@ -6,11 +6,46 @@ import OSLog
 public enum AsrModelVersion: Sendable {
     case v2
     case v3
+    /// 110M parameter hybrid TDT-CTC model with fused preprocessor+encoder
+    case tdtCtc110m
 
     var repo: Repo {
         switch self {
         case .v2: return .parakeetV2
         case .v3: return .parakeet
+        case .tdtCtc110m: return .parakeetTdtCtc110m
+        }
+    }
+
+    /// Whether this model version uses a fused preprocessor+encoder (no separate Encoder model)
+    public var hasFusedEncoder: Bool {
+        switch self {
+        case .tdtCtc110m: return true
+        default: return false
+        }
+    }
+
+    /// Encoder hidden dimension for this model version
+    public var encoderHiddenSize: Int {
+        switch self {
+        case .tdtCtc110m: return 512
+        default: return 1024
+        }
+    }
+
+    /// Blank token ID for this model version
+    public var blankId: Int {
+        switch self {
+        case .v2, .tdtCtc110m: return 1024
+        case .v3: return 8192
+        }
+    }
+
+    /// Number of LSTM layers in the decoder prediction network
+    public var decoderLayers: Int {
+        switch self {
+        case .tdtCtc110m: return 1
+        default: return 2
         }
     }
 }
@@ -20,7 +55,8 @@ public struct AsrModels: Sendable {
     /// Required model names for ASR
     public static let requiredModelNames = ModelNames.ASR.requiredModels
 
-    public let encoder: MLModel
+    /// Separate encoder model (nil for fused models like tdtCtc110m where preprocessor includes encoder)
+    public let encoder: MLModel?
     public let preprocessor: MLModel
     public let decoder: MLModel
     public let joint: MLModel
@@ -31,7 +67,7 @@
     private static let logger = AppLogger(category: "AsrModels")
 
     public init(
-        encoder: MLModel,
+        encoder: MLModel?,
         preprocessor: MLModel,
         decoder: MLModel,
         joint: MLModel,
@@ -48,8 +84,9 @@
         self.version = version
     }
 
+    /// Whether this model uses a separate preprocessor and encoder (true for 0.6B, false for 110m fused)
     public var usesSplitFrontend: Bool {
-        true
+        !version.hasFusedEncoder
     }
 }
 
@@ -60,7 +97,15 @@ extension AsrModels {
         let computeUnits: MLComputeUnits
     }
 
-    private static func createModelSpecs(using config: MLModelConfiguration) -> [ModelSpec] {
+    private static func createModelSpecs(
+        using config: MLModelConfiguration, version: AsrModelVersion
+    ) -> [ModelSpec] {
+        if version.hasFusedEncoder {
+            // Fused preprocessor+encoder runs on ANE (it contains the conformer encoder)
+            return [
+                ModelSpec(fileName: Names.preprocessorFile, computeUnits: config.computeUnits)
+            ]
+        }
         return [
             // Preprocessor ops map to CPU-only across all platforms. XCode profiling shows
             // that 100% of the the operations map to the CPU anyways.
@@ -78,7 +123,7 @@
 
     private static func inferredVersion(from directory: URL) -> AsrModelVersion? {
         let directoryPath = directory.path.lowercased()
-        let knownVersions: [AsrModelVersion] = [.v2, .v3]
+        let knownVersions: [AsrModelVersion] = [.tdtCtc110m, .v2, .v3]
 
         for version in knownVersions {
             if directoryPath.contains(version.repo.folderName.lowercased()) {
@@ -118,7 +163,7 @@
 
         let parentDirectory = directory.deletingLastPathComponent()
         // Load preprocessor and encoder first; decoder and joint are loaded below as well.
-        let specs = createModelSpecs(using: config)
+        let specs = createModelSpecs(using: config, version: version)
 
         var loadedModels: [String: MLModel] = [:]
 
@@ -138,10 +183,13 @@
             }
         }
 
-        guard let preprocessorModel = loadedModels[Names.preprocessorFile],
-            let encoderModel = loadedModels[Names.encoderFile]
-        else {
-            throw AsrModelsError.loadingFailed("Failed to load preprocessor or encoder model")
+        guard let preprocessorModel = loadedModels[Names.preprocessorFile] else {
+            throw AsrModelsError.loadingFailed("Failed to load preprocessor model")
+        }
+        let encoderModel = loadedModels[Names.encoderFile]  // nil for fused models
+
+        if !version.hasFusedEncoder && encoderModel == nil {
+            throw AsrModelsError.loadingFailed("Failed to load encoder model (required for split frontend)")
         }
 
         // Load decoder and joint as well
@@ -185,18 +233,30 @@
 
         do {
             let data = try Data(contentsOf: vocabPath)
-            let jsonDict = try JSONSerialization.jsonObject(with: data) as? [String: String] ?? [:]
+            let json = try JSONSerialization.jsonObject(with: data)
 
             var vocabulary: [Int: String] = [:]
 
-            for (key, value) in jsonDict {
-                if let tokenId = Int(key) {
-                    vocabulary[tokenId] = value
+            if let jsonArray = json as? [String] {
+                // Array format (110m hybrid): index = token ID
+                for (index, token) in jsonArray.enumerated() {
+                    vocabulary[index] = token
+                }
+            } else if let jsonDict = json as? [String: String] {
+                // Dictionary format (0.6B v2/v3): key = token ID string
+                for (key, value) in jsonDict {
+                    if let tokenId = Int(key) {
+                        vocabulary[tokenId] = value
+                    }
                 }
+            } else {
+                throw AsrModelsError.loadingFailed("Vocabulary file has unexpected format")
             }
 
             logger.info("Loaded vocabulary with \(vocabulary.count) tokens from \(vocabPath.path)")
             return vocabulary
+        } catch let error as AsrModelsError {
+            throw error
         } catch {
             logger.error(
                 "Failed to load or parse vocabulary file at \(vocabPath.path): \(error.localizedDescription)"
@@ -324,13 +384,23 @@
 
         let defaultUnits = defaultConfiguration().computeUnits
 
-        let specs: [DownloadSpec] = [
-            // Preprocessor ops map to CPU-only across all platforms.
-            DownloadSpec(fileName: Names.preprocessorFile, computeUnits: .cpuOnly),
-            DownloadSpec(fileName: Names.encoderFile, computeUnits: defaultUnits),
-            DownloadSpec(fileName: Names.decoderFile, computeUnits: defaultUnits),
-            DownloadSpec(fileName: Names.jointFile, computeUnits: defaultUnits),
-        ]
+        let specs: [DownloadSpec]
+        if version.hasFusedEncoder {
+            specs = [
+                // Fused preprocessor+encoder runs on ANE
+                DownloadSpec(fileName: Names.preprocessorFile, computeUnits: defaultUnits),
+                DownloadSpec(fileName: Names.decoderFile, computeUnits: defaultUnits),
+                DownloadSpec(fileName: Names.jointFile, computeUnits: defaultUnits),
+            ]
+        } else {
+            specs = [
+                // Preprocessor ops map to CPU-only across all platforms.
+                DownloadSpec(fileName: Names.preprocessorFile, computeUnits: .cpuOnly),
+                DownloadSpec(fileName: Names.encoderFile, computeUnits: defaultUnits),
+                DownloadSpec(fileName: Names.decoderFile, computeUnits: defaultUnits),
+                DownloadSpec(fileName: Names.jointFile, computeUnits: defaultUnits),
+            ]
+        }
 
         for spec in specs {
            _ = try await DownloadUtils.loadModels(
@@ -365,7 +435,8 @@
 
     public static func modelsExist(at directory: URL, version: AsrModelVersion) -> Bool {
         let fileManager = FileManager.default
-        let requiredFiles = ModelNames.ASR.requiredModels
+        let requiredFiles =
+            version.hasFusedEncoder ? ModelNames.ASR.requiredModelsFused : ModelNames.ASR.requiredModels
 
         // Check in the DownloadUtils repo structure
         let repoPath = repoPath(from: directory, version: version)
@@ -397,12 +468,14 @@
         let config = MLModelConfiguration()
         config.computeUnits = .cpuOnly
 
-        let modelsToValidate = [
+        var modelsToValidate = [
             ("Preprocessor", ModelNames.ASR.preprocessorFile),
-            ("Encoder", ModelNames.ASR.encoderFile),
             ("Decoder", ModelNames.ASR.decoderFile),
             ("Joint", ModelNames.ASR.jointFile),
         ]
+        if !version.hasFusedEncoder {
+            modelsToValidate.insert(("Encoder", ModelNames.ASR.encoderFile), at: 1)
+        }
 
         for (name, fileName) in modelsToValidate {
             let modelPath = repoPath.appendingPathComponent(fileName)
```
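The dual-format vocabulary handling in the AsrModels diff above can be exercised in isolation. This sketch mirrors only the parsing logic; the function name and error type are illustrative stand-ins, not FluidAudio API:

```swift
import Foundation

// Hypothetical error type standing in for AsrModelsError in the diff above.
enum VocabError: Error { case unexpectedFormat }

// Sketch of the dual-format vocabulary loading added in this PR:
// the 110m hybrid ships vocab.json as a JSON array, the 0.6B models
// as a {"tokenId": "token"} dictionary.
func parseVocabulary(_ data: Data) throws -> [Int: String] {
    let json = try JSONSerialization.jsonObject(with: data)
    var vocabulary: [Int: String] = [:]
    if let array = json as? [String] {
        // Array format (110m hybrid): the index is the token ID.
        for (index, token) in array.enumerated() { vocabulary[index] = token }
    } else if let dict = json as? [String: String] {
        // Dictionary format (0.6B v2/v3): the key is the token ID as a string.
        for (key, value) in dict {
            if let id = Int(key) { vocabulary[id] = value }
        }
    } else {
        throw VocabError.unexpectedFormat
    }
    return vocabulary
}
```

Normalizing both formats to `[Int: String]` at load time means the decoders never need to know which on-disk layout a given model shipped with.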
