diff --git a/Sources/FluidAudio/ASR/Parakeet/AsrTypes.swift b/Sources/FluidAudio/ASR/Parakeet/AsrTypes.swift index c4dcf2950..3b17629b4 100644 --- a/Sources/FluidAudio/ASR/Parakeet/AsrTypes.swift +++ b/Sources/FluidAudio/ASR/Parakeet/AsrTypes.swift @@ -9,6 +9,10 @@ public struct ASRConfig: Sendable { /// Encoder hidden dimension (1024 for 0.6B, 512 for 110m) public let encoderHiddenSize: Int + /// Number of long-form chunks to transcribe concurrently. + /// Applies only to stateless chunked transcription paths. + public let parallelChunkConcurrency: Int + /// Enable streaming mode for large files to reduce memory usage. /// When enabled, files larger than `streamingThreshold` samples will be processed /// using streaming to maintain constant memory usage. @@ -25,12 +29,14 @@ public struct ASRConfig: Sendable { sampleRate: Int = 16000, tdtConfig: TdtConfig = .default, encoderHiddenSize: Int = ASRConstants.encoderHiddenSize, + parallelChunkConcurrency: Int = 4, streamingEnabled: Bool = true, streamingThreshold: Int = 480_000 ) { self.sampleRate = sampleRate self.tdtConfig = tdtConfig self.encoderHiddenSize = encoderHiddenSize + self.parallelChunkConcurrency = max(1, parallelChunkConcurrency) self.streamingEnabled = streamingEnabled self.streamingThreshold = streamingThreshold } diff --git a/Sources/FluidAudio/ASR/Parakeet/SlidingWindow/TDT/AsrManager.swift b/Sources/FluidAudio/ASR/Parakeet/SlidingWindow/TDT/AsrManager.swift index 3e7d7aba8..51cb16099 100644 --- a/Sources/FluidAudio/ASR/Parakeet/SlidingWindow/TDT/AsrManager.swift +++ b/Sources/FluidAudio/ASR/Parakeet/SlidingWindow/TDT/AsrManager.swift @@ -25,6 +25,10 @@ public actor AsrManager { asrModels?.version.decoderLayers ?? 
2 } + internal var parallelChunkConcurrency: Int { + config.parallelChunkConcurrency + } + /// Cached vocabulary loaded once during initialization internal var vocabulary: [Int: String] = [:] #if DEBUG @@ -39,9 +43,18 @@ public actor AsrManager { AsrModels.optimizedPredictionOptions() }() - public init(config: ASRConfig = .default) { + public init(config: ASRConfig = .default, models: AsrModels? = nil) { self.config = config + if let models { + self.asrModels = models + self.preprocessorModel = models.preprocessor + self.encoderModel = models.encoder + self.decoderModel = models.decoder + self.jointModel = models.joint + self.vocabulary = models.vocabulary + } + // Pre-warm caches if possible Task { await sharedMLArrayCache.prewarm(shapes: [ @@ -63,6 +76,11 @@ public actor AsrManager { } } + internal func makeWorkerClone() -> AsrManager? { + guard let models = asrModels else { return nil } + return AsrManager(config: config, models: models) + } + /// Returns the current transcription progress stream for offline long audio (>240,000 samples / ~15s). /// Only one session is supported at a time. 
public var transcriptionProgressStream: AsyncThrowingStream { @@ -200,6 +218,7 @@ public actor AsrManager { sampleRate: config.sampleRate, tdtConfig: config.tdtConfig, encoderHiddenSize: models.version.encoderHiddenSize, + parallelChunkConcurrency: config.parallelChunkConcurrency, streamingEnabled: config.streamingEnabled, streamingThreshold: config.streamingThreshold ) diff --git a/Sources/FluidAudio/ASR/Parakeet/SlidingWindow/TDT/ChunkProcessor.swift b/Sources/FluidAudio/ASR/Parakeet/SlidingWindow/TDT/ChunkProcessor.swift index 1d553befc..8f3e1ac34 100644 --- a/Sources/FluidAudio/ASR/Parakeet/SlidingWindow/TDT/ChunkProcessor.swift +++ b/Sources/FluidAudio/ASR/Parakeet/SlidingWindow/TDT/ChunkProcessor.swift @@ -6,6 +6,11 @@ struct ChunkProcessor { private let logger = AppLogger(category: "ChunkProcessor") private typealias TokenWindow = (token: Int, timestamp: Int, confidence: Float, duration: Int) + private struct TaskResult: Sendable { + let index: Int + let tokens: [TokenWindow] + let workerIndex: Int + } private struct IndexedToken { let index: Int let token: TokenWindow @@ -58,74 +63,121 @@ struct ChunkProcessor { startTime: Date, progressHandler: ((Double) async -> Void)? = nil ) async throws -> ASRResult { - var chunkOutputs: [[TokenWindow]] = [] - + let requestedConcurrency = max(1, await manager.parallelChunkConcurrency) + let workers = await makeWorkerPool(using: manager, count: requestedConcurrency) ?? [manager] + let decoderLayers = await manager.decoderLayerCount + let maxModelSamples = self.maxModelSamples + + var chunkOutputs: [[TokenWindow]?] = [] + var availableWorkers = Array(workers.indices) + var inFlight = 0 var chunkStart = 0 var chunkIndex = 0 - var chunkDecoderState = TdtDecoderState.make( - decoderLayers: await manager.decoderLayerCount - ) - - while chunkStart < totalSamples { - try Task.checkCancellation() - let candidateEnd = chunkStart + chunkSamples - let isLastChunk = candidateEnd >= totalSamples - let chunkEnd = isLastChunk ? 
totalSamples : candidateEnd - - if chunkEnd <= chunkStart { - break - } - - chunkDecoderState.reset() - - // For chunks after the first, prepend context samples from the overlap region. - // This provides left context for the mel spectrogram STFT window and encoder convolutions. - let contextSamples = chunkIndex > 0 ? melContextSamples : 0 - let contextStart = chunkStart - contextSamples - let chunkLengthWithContext = chunkEnd - contextStart - let chunkSamplesArray = try readSamples(offset: contextStart, count: chunkLengthWithContext) - - let (windowTokens, windowTimestamps, windowConfidences, windowDurations) = try await transcribeChunk( - samples: chunkSamplesArray, - contextSamples: contextSamples, - chunkStart: chunkStart, - isLastChunk: isLastChunk, - using: manager, - decoderState: &chunkDecoderState - ) - - // Combine tokens, timestamps, and confidences into aligned tuples - guard windowTokens.count == windowTimestamps.count && windowTokens.count == windowConfidences.count else { - throw ASRError.processingFailed("Token, timestamp, and confidence arrays are misaligned") - } - // Default to 0 per token if durations array is misaligned (shouldn't happen in practice) - let durations = - windowDurations.count == windowTokens.count - ? 
windowDurations : Array(repeating: 0, count: windowTokens.count) - - let windowData: [TokenWindow] = zip( - zip(zip(windowTokens, windowTimestamps), windowConfidences), durations - ).map { - (token: $0.0.0.0, timestamp: $0.0.0.1, confidence: $0.0.1, duration: $0.1) - } - chunkOutputs.append(windowData) - - chunkIndex += 1 + func collectNextResult( + _ group: inout ThrowingTaskGroup<TaskResult, Error> + ) async throws { + guard inFlight > 0 else { return } + guard let finished = try await group.next() else { return } + chunkOutputs[finished.index] = finished.tokens + availableWorkers.append(finished.workerIndex) + inFlight -= 1 + } - if isLastChunk { - break + try await withThrowingTaskGroup(of: TaskResult.self) { group in + while chunkStart < totalSamples { + try Task.checkCancellation() + let candidateEnd = chunkStart + chunkSamples + let isLastChunk = candidateEnd >= totalSamples + let chunkEnd = isLastChunk ? totalSamples : candidateEnd + + if chunkEnd <= chunkStart { + break + } + + // For chunks after the first, prepend context samples from the overlap region. + // This provides left context for the mel spectrogram STFT window and encoder convolutions. + let contextSamples = chunkIndex > 0 ? 
melContextSamples : 0 + let contextStart = chunkStart - contextSamples + let chunkLengthWithContext = chunkEnd - contextStart + let chunkSamplesArray = try readSamples(offset: contextStart, count: chunkLengthWithContext) + + if availableWorkers.isEmpty { + try await collectNextResult(&group) + } + if availableWorkers.isEmpty { + availableWorkers.append(0) + } + + let workerIndex = availableWorkers.removeFirst() + let worker = workers[workerIndex] + let index = chunkIndex + let chunkStartOffset = chunkStart + chunkOutputs.append(nil) + + group.addTask { + var decoderState = TdtDecoderState.make(decoderLayers: decoderLayers) + decoderState.reset() + + let (windowTokens, windowTimestamps, windowConfidences, windowDurations) = + try await Self + .transcribeChunk( + samples: chunkSamplesArray, + contextSamples: contextSamples, + chunkStart: chunkStartOffset, + isLastChunk: isLastChunk, + using: worker, + decoderState: &decoderState, + maxModelSamples: maxModelSamples + ) + + guard + windowTokens.count == windowTimestamps.count + && windowTokens.count == windowConfidences.count + else { + throw ASRError.processingFailed("Token, timestamp, and confidence arrays are misaligned") + } + + let durations = + windowDurations.count == windowTokens.count + ? 
windowDurations : Array(repeating: 0, count: windowTokens.count) + + let windowData: [TokenWindow] = zip( + zip(zip(windowTokens, windowTimestamps), windowConfidences), durations + ).map { + (token: $0.0.0.0, timestamp: $0.0.0.1, confidence: $0.0.1, duration: $0.1) + } + + return TaskResult(index: index, tokens: windowData, workerIndex: workerIndex) + } + inFlight += 1 + chunkIndex += 1 + + if let progressHandler, !isLastChunk { + let progress = min(1.0, max(0.0, Double(chunkEnd) / Double(totalSamples))) + await progressHandler(progress) + } + + if isLastChunk { + break + } + + chunkStart += strideSamples + + if availableWorkers.isEmpty && inFlight > 0 { + try await collectNextResult(&group) + } + } - if let progressHandler { - let progress = min(1.0, max(0.0, Double(chunkEnd) / Double(totalSamples))) - await progressHandler(progress) + while inFlight > 0 { + try Task.checkCancellation() + try await collectNextResult(&group) } - - chunkStart += strideSamples } - guard var mergedTokens = chunkOutputs.first else { + let orderedChunkOutputs = chunkOutputs.compactMap { $0 } + + guard var mergedTokens = orderedChunkOutputs.first else { return await manager.processTranscriptionResult( tokenIds: [], timestamps: [], @@ -136,8 +188,8 @@ struct ChunkProcessor { ) } - if chunkOutputs.count > 1 { - for chunk in chunkOutputs.dropFirst() { + if orderedChunkOutputs.count > 1 { + for chunk in orderedChunkOutputs.dropFirst() { mergedTokens = mergeChunks(mergedTokens, chunk) } } @@ -162,6 +214,22 @@ struct ChunkProcessor { ) } + private func makeWorkerPool(using manager: AsrManager, count: Int) async -> [AsrManager]? { + guard count > 0 else { return nil } + var workers: [AsrManager] = [manager] + if count == 1 { + return workers + } + for _ in 1..<count { + guard let worker = await manager.makeWorkerClone() else { break } + workers.append(worker) + } + return workers + } private func readSamples(offset: Int, count: Int) throws -> 
[Float] { var buffer = [Float](repeating: 0, count: count) try buffer.withUnsafeMutableBufferPointer { pointer in @@ -170,13 +238,14 @@ struct ChunkProcessor { return buffer } - private func transcribeChunk( + private static func transcribeChunk( samples: [Float], contextSamples: Int, chunkStart: Int, isLastChunk: Bool, using manager: AsrManager, - decoderState: inout TdtDecoderState + decoderState: inout TdtDecoderState, + maxModelSamples: Int ) async throws -> (tokens: [Int], timestamps: [Int], confidences: [Float], durations: [Int]) { guard !samples.isEmpty else { return ([], [], [], []) } diff --git a/Sources/FluidAudio/ASR/Parakeet/SlidingWindow/TDT/Decoder/TdtDecoderV2.swift b/Sources/FluidAudio/ASR/Parakeet/SlidingWindow/TDT/Decoder/TdtDecoderV2.swift index 7037db7d2..5cdec30e8 100644 --- a/Sources/FluidAudio/ASR/Parakeet/SlidingWindow/TDT/Decoder/TdtDecoderV2.swift +++ b/Sources/FluidAudio/ASR/Parakeet/SlidingWindow/TDT/Decoder/TdtDecoderV2.swift @@ -69,7 +69,10 @@ internal struct TdtDecoderV2 { return ASRConfig( sampleRate: config.sampleRate, tdtConfig: adaptedTdt, - encoderHiddenSize: config.encoderHiddenSize + encoderHiddenSize: config.encoderHiddenSize, + parallelChunkConcurrency: config.parallelChunkConcurrency, + streamingEnabled: config.streamingEnabled, + streamingThreshold: config.streamingThreshold ) } } diff --git a/Tests/FluidAudioTests/ASR/Parakeet/ASRConfigTests.swift b/Tests/FluidAudioTests/ASR/Parakeet/ASRConfigTests.swift new file mode 100644 index 000000000..204304d1a --- /dev/null +++ b/Tests/FluidAudioTests/ASR/Parakeet/ASRConfigTests.swift @@ -0,0 +1,19 @@ +import XCTest + +@testable import FluidAudio + +final class ASRConfigTests: XCTestCase { + + func testDefaultParallelChunkConcurrency() { + XCTAssertEqual(ASRConfig.default.parallelChunkConcurrency, 4) + } + + func testParallelChunkConcurrencyClampsToAtLeastOne() { + XCTAssertEqual(ASRConfig(parallelChunkConcurrency: 0).parallelChunkConcurrency, 1) + 
XCTAssertEqual(ASRConfig(parallelChunkConcurrency: -3).parallelChunkConcurrency, 1) + } + + func testParallelChunkConcurrencyPreservesExplicitValue() { + XCTAssertEqual(ASRConfig(parallelChunkConcurrency: 6).parallelChunkConcurrency, 6) + } +}