Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions Sources/FluidAudio/ASR/Parakeet/AsrTypes.swift
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ public struct ASRConfig: Sendable {
/// Encoder hidden dimension (1024 for 0.6B, 512 for 110m)
public let encoderHiddenSize: Int

/// Number of long-form chunks to transcribe concurrently.
/// Applies only to stateless chunked transcription paths.
public let parallelChunkConcurrency: Int

/// Enable streaming mode for large files to reduce memory usage.
/// When enabled, files larger than `streamingThreshold` samples will be processed
/// using streaming to maintain constant memory usage.
Expand All @@ -25,12 +29,14 @@ public struct ASRConfig: Sendable {
sampleRate: Int = 16000,
tdtConfig: TdtConfig = .default,
encoderHiddenSize: Int = ASRConstants.encoderHiddenSize,
parallelChunkConcurrency: Int = 4,
streamingEnabled: Bool = true,
streamingThreshold: Int = 480_000
) {
self.sampleRate = sampleRate
self.tdtConfig = tdtConfig
self.encoderHiddenSize = encoderHiddenSize
self.parallelChunkConcurrency = max(1, parallelChunkConcurrency)
self.streamingEnabled = streamingEnabled
self.streamingThreshold = streamingThreshold
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ public actor AsrManager {
asrModels?.version.decoderLayers ?? 2
}

internal var parallelChunkConcurrency: Int {
config.parallelChunkConcurrency
}

/// Cached vocabulary loaded once during initialization
internal var vocabulary: [Int: String] = [:]
#if DEBUG
Expand All @@ -39,9 +43,18 @@ public actor AsrManager {
AsrModels.optimizedPredictionOptions()
}()

public init(config: ASRConfig = .default) {
public init(config: ASRConfig = .default, models: AsrModels? = nil) {
self.config = config

if let models {
self.asrModels = models
self.preprocessorModel = models.preprocessor
self.encoderModel = models.encoder
self.decoderModel = models.decoder
self.jointModel = models.joint
self.vocabulary = models.vocabulary
}

// Pre-warm caches if possible
Task {
await sharedMLArrayCache.prewarm(shapes: [
Expand All @@ -63,6 +76,11 @@ public actor AsrManager {
}
}

internal func makeWorkerClone() -> AsrManager? {
guard let models = asrModels else { return nil }
return AsrManager(config: config, models: models)
}

/// Returns the current transcription progress stream for offline long audio (>240,000 samples / ~15s).
/// Only one session is supported at a time.
public var transcriptionProgressStream: AsyncThrowingStream<Double, Error> {
Expand Down Expand Up @@ -200,6 +218,7 @@ public actor AsrManager {
sampleRate: config.sampleRate,
tdtConfig: config.tdtConfig,
encoderHiddenSize: models.version.encoderHiddenSize,
parallelChunkConcurrency: config.parallelChunkConcurrency,
streamingEnabled: config.streamingEnabled,
streamingThreshold: config.streamingThreshold
)
Expand Down
195 changes: 132 additions & 63 deletions Sources/FluidAudio/ASR/Parakeet/SlidingWindow/TDT/ChunkProcessor.swift
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@ struct ChunkProcessor {

private let logger = AppLogger(category: "ChunkProcessor")
private typealias TokenWindow = (token: Int, timestamp: Int, confidence: Float, duration: Int)
    /// Result of one concurrently transcribed chunk, tagged so outputs can be
    /// reassembled in submission order and the worker returned to the pool.
    private struct TaskResult: Sendable {
        // Position of the chunk in the audio; used to restore ordering when merging.
        let index: Int
        // Decoded token windows (token, timestamp, confidence, duration) for the chunk.
        let tokens: [TokenWindow]
        // Index into the worker pool; the worker is marked available again on completion.
        let workerIndex: Int
    }
private struct IndexedToken {
let index: Int
let token: TokenWindow
Expand Down Expand Up @@ -58,74 +63,121 @@ struct ChunkProcessor {
startTime: Date,
progressHandler: ((Double) async -> Void)? = nil
) async throws -> ASRResult {
var chunkOutputs: [[TokenWindow]] = []

let requestedConcurrency = max(1, await manager.parallelChunkConcurrency)
let workers = await makeWorkerPool(using: manager, count: requestedConcurrency) ?? [manager]
let decoderLayers = await manager.decoderLayerCount
let maxModelSamples = self.maxModelSamples

var chunkOutputs: [[TokenWindow]?] = []
var availableWorkers = Array(workers.indices)
var inFlight = 0
var chunkStart = 0
var chunkIndex = 0
var chunkDecoderState = TdtDecoderState.make(
decoderLayers: await manager.decoderLayerCount
)

while chunkStart < totalSamples {
try Task.checkCancellation()
let candidateEnd = chunkStart + chunkSamples
let isLastChunk = candidateEnd >= totalSamples
let chunkEnd = isLastChunk ? totalSamples : candidateEnd

if chunkEnd <= chunkStart {
break
}

chunkDecoderState.reset()

// For chunks after the first, prepend context samples from the overlap region.
// This provides left context for the mel spectrogram STFT window and encoder convolutions.
let contextSamples = chunkIndex > 0 ? melContextSamples : 0
let contextStart = chunkStart - contextSamples
let chunkLengthWithContext = chunkEnd - contextStart
let chunkSamplesArray = try readSamples(offset: contextStart, count: chunkLengthWithContext)

let (windowTokens, windowTimestamps, windowConfidences, windowDurations) = try await transcribeChunk(
samples: chunkSamplesArray,
contextSamples: contextSamples,
chunkStart: chunkStart,
isLastChunk: isLastChunk,
using: manager,
decoderState: &chunkDecoderState
)

// Combine tokens, timestamps, and confidences into aligned tuples
guard windowTokens.count == windowTimestamps.count && windowTokens.count == windowConfidences.count else {
throw ASRError.processingFailed("Token, timestamp, and confidence arrays are misaligned")
}

// Default to 0 per token if durations array is misaligned (shouldn't happen in practice)
let durations =
windowDurations.count == windowTokens.count
? windowDurations : Array(repeating: 0, count: windowTokens.count)

let windowData: [TokenWindow] = zip(
zip(zip(windowTokens, windowTimestamps), windowConfidences), durations
).map {
(token: $0.0.0.0, timestamp: $0.0.0.1, confidence: $0.0.1, duration: $0.1)
}
chunkOutputs.append(windowData)

chunkIndex += 1
func collectNextResult(
_ group: inout ThrowingTaskGroup<TaskResult, Error>
) async throws {
guard inFlight > 0 else { return }
guard let finished = try await group.next() else { return }
chunkOutputs[finished.index] = finished.tokens
availableWorkers.append(finished.workerIndex)
inFlight -= 1
}

if isLastChunk {
break
try await withThrowingTaskGroup(of: TaskResult.self) { group in
while chunkStart < totalSamples {
try Task.checkCancellation()
let candidateEnd = chunkStart + chunkSamples
let isLastChunk = candidateEnd >= totalSamples
let chunkEnd = isLastChunk ? totalSamples : candidateEnd

if chunkEnd <= chunkStart {
break
}

// For chunks after the first, prepend context samples from the overlap region.
// This provides left context for the mel spectrogram STFT window and encoder convolutions.
let contextSamples = chunkIndex > 0 ? melContextSamples : 0
let contextStart = chunkStart - contextSamples
let chunkLengthWithContext = chunkEnd - contextStart
let chunkSamplesArray = try readSamples(offset: contextStart, count: chunkLengthWithContext)

if availableWorkers.isEmpty {
try await collectNextResult(&group)
}
if availableWorkers.isEmpty {
availableWorkers.append(0)
}

let workerIndex = availableWorkers.removeFirst()
let worker = workers[workerIndex]
let index = chunkIndex
let chunkStartOffset = chunkStart
chunkOutputs.append(nil)

group.addTask {
var decoderState = TdtDecoderState.make(decoderLayers: decoderLayers)
decoderState.reset()

let (windowTokens, windowTimestamps, windowConfidences, windowDurations) =
try await Self
.transcribeChunk(
samples: chunkSamplesArray,
contextSamples: contextSamples,
chunkStart: chunkStartOffset,
isLastChunk: isLastChunk,
using: worker,
decoderState: &decoderState,
maxModelSamples: maxModelSamples
)

guard
windowTokens.count == windowTimestamps.count
&& windowTokens.count == windowConfidences.count
else {
throw ASRError.processingFailed("Token, timestamp, and confidence arrays are misaligned")
}

let durations =
windowDurations.count == windowTokens.count
? windowDurations : Array(repeating: 0, count: windowTokens.count)

let windowData: [TokenWindow] = zip(
zip(zip(windowTokens, windowTimestamps), windowConfidences), durations
).map {
(token: $0.0.0.0, timestamp: $0.0.0.1, confidence: $0.0.1, duration: $0.1)
}

return TaskResult(index: index, tokens: windowData, workerIndex: workerIndex)
}
inFlight += 1
chunkIndex += 1

if let progressHandler, !isLastChunk {
let progress = min(1.0, max(0.0, Double(chunkEnd) / Double(totalSamples)))
await progressHandler(progress)
}

if isLastChunk {
break
}

chunkStart += strideSamples

if availableWorkers.isEmpty && inFlight > 0 {
try await collectNextResult(&group)
}
}

if let progressHandler {
let progress = min(1.0, max(0.0, Double(chunkEnd) / Double(totalSamples)))
await progressHandler(progress)
while inFlight > 0 {
try Task.checkCancellation()
try await collectNextResult(&group)
}

chunkStart += strideSamples
}

guard var mergedTokens = chunkOutputs.first else {
let orderedChunkOutputs = chunkOutputs.compactMap { $0 }

guard var mergedTokens = orderedChunkOutputs.first else {
return await manager.processTranscriptionResult(
tokenIds: [],
timestamps: [],
Expand All @@ -136,8 +188,8 @@ struct ChunkProcessor {
)
}

if chunkOutputs.count > 1 {
for chunk in chunkOutputs.dropFirst() {
if orderedChunkOutputs.count > 1 {
for chunk in orderedChunkOutputs.dropFirst() {
mergedTokens = mergeChunks(mergedTokens, chunk)
}
}
Expand All @@ -162,6 +214,22 @@ struct ChunkProcessor {
)
}

private func makeWorkerPool(using manager: AsrManager, count: Int) async -> [AsrManager]? {
guard count > 0 else { return nil }
var workers: [AsrManager] = [manager]
if count == 1 {
return workers
}
for _ in 1..<count {
guard let clone = await manager.makeWorkerClone() else {
return nil
}
workers.append(clone)
}
logger.debug("ChunkProcessor using worker pool of size \(workers.count)")
return workers
}

private func readSamples(offset: Int, count: Int) throws -> [Float] {
var buffer = [Float](repeating: 0, count: count)
try buffer.withUnsafeMutableBufferPointer { pointer in
Expand All @@ -170,13 +238,14 @@ struct ChunkProcessor {
return buffer
}

private func transcribeChunk(
private static func transcribeChunk(
samples: [Float],
contextSamples: Int,
chunkStart: Int,
isLastChunk: Bool,
using manager: AsrManager,
decoderState: inout TdtDecoderState
decoderState: inout TdtDecoderState,
maxModelSamples: Int
) async throws -> (tokens: [Int], timestamps: [Int], confidences: [Float], durations: [Int]) {
guard !samples.isEmpty else { return ([], [], [], []) }

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,10 @@ internal struct TdtDecoderV2 {
return ASRConfig(
sampleRate: config.sampleRate,
tdtConfig: adaptedTdt,
encoderHiddenSize: config.encoderHiddenSize
encoderHiddenSize: config.encoderHiddenSize,
parallelChunkConcurrency: config.parallelChunkConcurrency,
streamingEnabled: config.streamingEnabled,
streamingThreshold: config.streamingThreshold
)
}
}
19 changes: 19 additions & 0 deletions Tests/FluidAudioTests/ASR/Parakeet/ASRConfigTests.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import XCTest

@testable import FluidAudio

final class ASRConfigTests: XCTestCase {

    /// The default configuration allows four chunks in flight.
    func testDefaultParallelChunkConcurrency() {
        XCTAssertEqual(ASRConfig.default.parallelChunkConcurrency, 4)
    }

    /// Zero and negative requests are clamped up to a single worker.
    func testParallelChunkConcurrencyClampsToAtLeastOne() {
        for requested in [0, -3] {
            XCTAssertEqual(
                ASRConfig(parallelChunkConcurrency: requested).parallelChunkConcurrency,
                1,
                "requested \(requested) should clamp to 1"
            )
        }
    }

    /// A positive explicit value is preserved unchanged.
    func testParallelChunkConcurrencyPreservesExplicitValue() {
        XCTAssertEqual(ASRConfig(parallelChunkConcurrency: 6).parallelChunkConcurrency, 6)
    }
}
Loading