Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
473 changes: 473 additions & 0 deletions Documentation/ASR/TDT-CTC-110M.md

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions Documentation/Models.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ A guide to each CoreML model pipeline in FluidAudio.
|-------|-------------|---------|
| **Parakeet TDT v2** | Batch speech-to-text, English only (0.6B params). TDT architecture. | First ASR model added. |
| **Parakeet TDT v3** | Batch speech-to-text, 25 European languages (0.6B params). Default ASR model. | Released after v2 to add multilingual support. |
| **Parakeet TDT-CTC-110M** | Hybrid TDT-CTC batch model (110M params). 3.01% WER on LibriSpeech test-clean. 96.5x RTFx on M2 Mac. Fused preprocessor+encoder for reduced memory footprint. iOS compatible. | Smaller, faster alternative to v3 with competitive accuracy. |

TDT models process audio in chunks (~15s with overlap) as batch operations. Fast enough for dictation-style workflows. Not suitable for word-by-word live captions.

Expand Down Expand Up @@ -63,6 +64,7 @@ Models we converted and tested but haven't shipped yet — either still in devel
|-------|-----------------|
| Parakeet TDT v3 | [FluidInference/parakeet-tdt-0.6b-v3-coreml](https://huggingface.co/FluidInference/parakeet-tdt-0.6b-v3-coreml) |
| Parakeet TDT v2 | [FluidInference/parakeet-tdt-0.6b-v2-coreml](https://huggingface.co/FluidInference/parakeet-tdt-0.6b-v2-coreml) |
| Parakeet TDT-CTC-110M | [FluidInference/parakeet-tdt-ctc-110m-coreml](https://huggingface.co/FluidInference/parakeet-tdt-ctc-110m-coreml) |
| Parakeet CTC 110M | [FluidInference/parakeet-ctc-110m-coreml](https://huggingface.co/FluidInference/parakeet-ctc-110m-coreml) |
| Parakeet CTC 0.6B | [FluidInference/parakeet-ctc-0.6b-coreml](https://huggingface.co/FluidInference/parakeet-ctc-0.6b-coreml) |
| Parakeet EOU | [FluidInference/parakeet-realtime-eou-120m-coreml](https://huggingface.co/FluidInference/parakeet-realtime-eou-120m-coreml) |
Expand Down
58 changes: 45 additions & 13 deletions Sources/FluidAudio/ASR/AsrManager.swift
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,16 @@ public actor AsrManager {
internal var jointModel: MLModel?

/// The AsrModels instance if initialized with models
private var asrModels: AsrModels?
internal var asrModels: AsrModels?

internal let progressEmitter = ProgressEmitter()

/// Get the number of decoder layers for the current model.
/// Returns 2 if models not loaded (v2/v3 default, tdtCtc110m uses 1).
internal func getDecoderLayers() -> Int {
    guard let models = asrModels else {
        // No models loaded yet: fall back to the 2-layer layout used by v2/v3.
        return 2
    }
    return models.version.decoderLayers
}

/// Token duration optimization model

/// Cached vocabulary loaded once during initialization
Expand Down Expand Up @@ -88,14 +94,16 @@ public actor AsrManager {
}

public var isAvailable: Bool {
let baseModelsReady = encoderModel != nil && decoderModel != nil && jointModel != nil
guard baseModelsReady else { return false }
let decoderReady = decoderModel != nil && jointModel != nil
guard decoderReady else { return false }

if asrModels?.usesSplitFrontend == true {
// Split frontend: need both preprocessor and encoder
return preprocessorModel != nil && encoderModel != nil
} else {
// Fused frontend: preprocessor contains encoder
return preprocessorModel != nil
}

return true
}

/// Initialize ASR Manager with pre-loaded models
Expand All @@ -110,7 +118,10 @@ public actor AsrManager {
self.jointModel = models.joint
self.vocabulary = models.vocabulary

logger.info("Token duration optimization model loaded successfully")
// Recreate decoder states with the correct layer count for this model version
let layers = models.version.decoderLayers
self.microphoneDecoderState = TdtDecoderState.make(decoderLayers: layers)
self.systemDecoderState = TdtDecoderState.make(decoderLayers: layers)

logger.info("AsrManager initialized successfully with provided models")
}
Expand Down Expand Up @@ -293,19 +304,24 @@ public actor AsrManager {
}

public func resetState() {
microphoneDecoderState = TdtDecoderState.make()
systemDecoderState = TdtDecoderState.make()
// Use model's decoder layer count, or 2 if models not loaded (v2/v3 default)
let layers = asrModels?.version.decoderLayers ?? 2
microphoneDecoderState = TdtDecoderState.make(decoderLayers: layers)
systemDecoderState = TdtDecoderState.make(decoderLayers: layers)
Task { await sharedMLArrayCache.clear() }
}

public func cleanup() {
// Capture layer count before releasing models, fallback to 2 (v2/v3 default)
let layers = asrModels?.version.decoderLayers ?? 2
asrModels = nil
preprocessorModel = nil
encoderModel = nil
decoderModel = nil
jointModel = nil
// Reset decoder states using fresh allocations for deterministic behavior
microphoneDecoderState = TdtDecoderState.make()
systemDecoderState = TdtDecoderState.make()
microphoneDecoderState = TdtDecoderState.make(decoderLayers: layers)
systemDecoderState = TdtDecoderState.make(decoderLayers: layers)
// Release vocabulary boosting resources
disableVocabularyBoosting()
Task { await sharedMLArrayCache.clear() }
Expand All @@ -326,9 +342,25 @@ public actor AsrManager {
guard let models = asrModels, let decoder_ = decoderModel, let joint = jointModel else {
throw ASRError.notInitialized
}

// Adapt config's encoderHiddenSize to match the loaded model version
// (e.g. default config uses 1024 but tdtCtc110m needs 512)
let adaptedConfig: ASRConfig
if config.encoderHiddenSize != models.version.encoderHiddenSize {
adaptedConfig = ASRConfig(
sampleRate: config.sampleRate,
tdtConfig: config.tdtConfig,
encoderHiddenSize: models.version.encoderHiddenSize,
streamingEnabled: config.streamingEnabled,
streamingThreshold: config.streamingThreshold
)
} else {
adaptedConfig = config
}

switch models.version {
case .v2:
let decoder = TdtDecoderV2(config: config)
case .v2, .tdtCtc110m:
let decoder = TdtDecoderV2(config: adaptedConfig)
return try await decoder.decodeWithTimings(
encoderOutput: encoderOutput,
encoderSequenceLength: encoderSequenceLength,
Expand All @@ -341,7 +373,7 @@ public actor AsrManager {
globalFrameOffset: globalFrameOffset
)
case .v3:
let decoder = TdtDecoderV3(config: config)
let decoder = TdtDecoderV3(config: adaptedConfig)
return try await decoder.decodeWithTimings(
encoderOutput: encoderOutput,
encoderSequenceLength: encoderSequenceLength,
Expand Down
121 changes: 97 additions & 24 deletions Sources/FluidAudio/ASR/AsrModels.swift
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,46 @@ import OSLog
/// Supported Parakeet ASR model versions and their per-version constants.
///
/// Each case carries the model-geometry facts the rest of the pipeline needs
/// (repo location, frontend layout, encoder width, blank token, decoder depth).
/// All switches below are exhaustive on purpose: adding a new case must produce
/// compile errors here so every constant is reviewed, rather than silently
/// inheriting the v2/v3 defaults via `default:`.
public enum AsrModelVersion: Sendable {
    case v2
    case v3
    /// 110M parameter hybrid TDT-CTC model with fused preprocessor+encoder
    case tdtCtc110m

    /// Hugging Face repository that hosts this version's CoreML bundle.
    var repo: Repo {
        switch self {
        case .v2: return .parakeetV2
        case .v3: return .parakeet
        case .tdtCtc110m: return .parakeetTdtCtc110m
        }
    }

    /// Whether this model version uses a fused preprocessor+encoder (no separate Encoder model)
    public var hasFusedEncoder: Bool {
        switch self {
        case .tdtCtc110m: return true
        case .v2, .v3: return false
        }
    }

    /// Encoder hidden dimension for this model version
    public var encoderHiddenSize: Int {
        switch self {
        case .tdtCtc110m: return 512
        case .v2, .v3: return 1024
        }
    }

    /// Blank token ID for this model version
    public var blankId: Int {
        switch self {
        case .v2, .tdtCtc110m: return 1024
        case .v3: return 8192
        }
    }

    /// Number of LSTM layers in the decoder prediction network
    public var decoderLayers: Int {
        switch self {
        case .tdtCtc110m: return 1
        case .v2, .v3: return 2
        }
    }
}
Expand All @@ -20,7 +55,8 @@ public struct AsrModels: Sendable {
/// Required model names for ASR
public static let requiredModelNames = ModelNames.ASR.requiredModels

public let encoder: MLModel
/// Separate encoder model (nil for fused models like tdtCtc110m where preprocessor includes encoder)
public let encoder: MLModel?
public let preprocessor: MLModel
public let decoder: MLModel
public let joint: MLModel
Expand All @@ -31,7 +67,7 @@ public struct AsrModels: Sendable {
private static let logger = AppLogger(category: "AsrModels")

public init(
encoder: MLModel,
encoder: MLModel?,
preprocessor: MLModel,
decoder: MLModel,
joint: MLModel,
Expand All @@ -48,8 +84,9 @@ public struct AsrModels: Sendable {
self.version = version
}

/// Whether this model uses a separate preprocessor and encoder (true for 0.6B, false for 110m fused)
public var usesSplitFrontend: Bool {
true
!version.hasFusedEncoder
}
}

Expand All @@ -60,7 +97,15 @@ extension AsrModels {
let computeUnits: MLComputeUnits
}

private static func createModelSpecs(using config: MLModelConfiguration) -> [ModelSpec] {
private static func createModelSpecs(
using config: MLModelConfiguration, version: AsrModelVersion
) -> [ModelSpec] {
if version.hasFusedEncoder {
// Fused preprocessor+encoder runs on ANE (it contains the conformer encoder)
return [
ModelSpec(fileName: Names.preprocessorFile, computeUnits: config.computeUnits)
]
}
return [
// Preprocessor ops map to CPU-only across all platforms. Xcode profiling shows
// that 100% of the operations map to the CPU anyway.
Expand All @@ -78,7 +123,7 @@ extension AsrModels {

private static func inferredVersion(from directory: URL) -> AsrModelVersion? {
let directoryPath = directory.path.lowercased()
let knownVersions: [AsrModelVersion] = [.v2, .v3]
let knownVersions: [AsrModelVersion] = [.tdtCtc110m, .v2, .v3]

for version in knownVersions {
if directoryPath.contains(version.repo.folderName.lowercased()) {
Expand Down Expand Up @@ -118,7 +163,7 @@ extension AsrModels {

let parentDirectory = directory.deletingLastPathComponent()
// Load the frontend first (preprocessor, plus the separate encoder for split-frontend versions); decoder and joint are loaded below as well.
let specs = createModelSpecs(using: config)
let specs = createModelSpecs(using: config, version: version)

var loadedModels: [String: MLModel] = [:]

Expand All @@ -138,10 +183,13 @@ extension AsrModels {
}
}

guard let preprocessorModel = loadedModels[Names.preprocessorFile],
let encoderModel = loadedModels[Names.encoderFile]
else {
throw AsrModelsError.loadingFailed("Failed to load preprocessor or encoder model")
guard let preprocessorModel = loadedModels[Names.preprocessorFile] else {
throw AsrModelsError.loadingFailed("Failed to load preprocessor model")
}
let encoderModel = loadedModels[Names.encoderFile] // nil for fused models

if !version.hasFusedEncoder && encoderModel == nil {
throw AsrModelsError.loadingFailed("Failed to load encoder model (required for split frontend)")
}

// Load decoder and joint as well
Expand Down Expand Up @@ -185,18 +233,30 @@ extension AsrModels {

do {
let data = try Data(contentsOf: vocabPath)
let jsonDict = try JSONSerialization.jsonObject(with: data) as? [String: String] ?? [:]
let json = try JSONSerialization.jsonObject(with: data)

var vocabulary: [Int: String] = [:]

for (key, value) in jsonDict {
if let tokenId = Int(key) {
vocabulary[tokenId] = value
if let jsonArray = json as? [String] {
// Array format (110m hybrid): index = token ID
for (index, token) in jsonArray.enumerated() {
vocabulary[index] = token
}
} else if let jsonDict = json as? [String: String] {
// Dictionary format (0.6B v2/v3): key = token ID string
for (key, value) in jsonDict {
if let tokenId = Int(key) {
vocabulary[tokenId] = value
}
}
} else {
throw AsrModelsError.loadingFailed("Vocabulary file has unexpected format")
}

logger.info("Loaded vocabulary with \(vocabulary.count) tokens from \(vocabPath.path)")
return vocabulary
} catch let error as AsrModelsError {
throw error
} catch {
logger.error(
"Failed to load or parse vocabulary file at \(vocabPath.path): \(error.localizedDescription)"
Expand Down Expand Up @@ -324,13 +384,23 @@ extension AsrModels {

let defaultUnits = defaultConfiguration().computeUnits

let specs: [DownloadSpec] = [
// Preprocessor ops map to CPU-only across all platforms.
DownloadSpec(fileName: Names.preprocessorFile, computeUnits: .cpuOnly),
DownloadSpec(fileName: Names.encoderFile, computeUnits: defaultUnits),
DownloadSpec(fileName: Names.decoderFile, computeUnits: defaultUnits),
DownloadSpec(fileName: Names.jointFile, computeUnits: defaultUnits),
]
let specs: [DownloadSpec]
if version.hasFusedEncoder {
specs = [
// Fused preprocessor+encoder runs on ANE
DownloadSpec(fileName: Names.preprocessorFile, computeUnits: defaultUnits),
DownloadSpec(fileName: Names.decoderFile, computeUnits: defaultUnits),
DownloadSpec(fileName: Names.jointFile, computeUnits: defaultUnits),
]
} else {
specs = [
// Preprocessor ops map to CPU-only across all platforms.
DownloadSpec(fileName: Names.preprocessorFile, computeUnits: .cpuOnly),
DownloadSpec(fileName: Names.encoderFile, computeUnits: defaultUnits),
DownloadSpec(fileName: Names.decoderFile, computeUnits: defaultUnits),
DownloadSpec(fileName: Names.jointFile, computeUnits: defaultUnits),
]
}

for spec in specs {
_ = try await DownloadUtils.loadModels(
Expand Down Expand Up @@ -365,7 +435,8 @@ extension AsrModels {

public static func modelsExist(at directory: URL, version: AsrModelVersion) -> Bool {
let fileManager = FileManager.default
let requiredFiles = ModelNames.ASR.requiredModels
let requiredFiles =
version.hasFusedEncoder ? ModelNames.ASR.requiredModelsFused : ModelNames.ASR.requiredModels

// Check in the DownloadUtils repo structure
let repoPath = repoPath(from: directory, version: version)
Expand Down Expand Up @@ -397,12 +468,14 @@ extension AsrModels {
let config = MLModelConfiguration()
config.computeUnits = .cpuOnly

let modelsToValidate = [
var modelsToValidate = [
("Preprocessor", ModelNames.ASR.preprocessorFile),
("Encoder", ModelNames.ASR.encoderFile),
("Decoder", ModelNames.ASR.decoderFile),
("Joint", ModelNames.ASR.jointFile),
]
if !version.hasFusedEncoder {
modelsToValidate.insert(("Encoder", ModelNames.ASR.encoderFile), at: 1)
}

for (name, fileName) in modelsToValidate {
let modelPath = repoPath.appendingPathComponent(fileName)
Expand Down
Loading
Loading