diff --git a/Documentation/ASR/DirectoryStructure.md b/Documentation/ASR/DirectoryStructure.md index b2ab706f9..798ec3cda 100644 --- a/Documentation/ASR/DirectoryStructure.md +++ b/Documentation/ASR/DirectoryStructure.md @@ -74,7 +74,12 @@ ASR/ │ │ ├── TdtDecoderState.swift │ │ ├── TdtDecoderV2.swift │ │ ├── TdtDecoderV3.swift -│ │ └── TdtHypothesis.swift +│ │ ├── TdtHypothesis.swift +│ │ ├── TdtModelInference.swift (Model inference operations) +│ │ ├── TdtJointDecision.swift (Joint network decision structure) +│ │ ├── TdtJointInputProvider.swift (Reusable feature provider) +│ │ ├── TdtDurationMapping.swift (Duration bin mapping utilities) +│ │ └── TdtFrameNavigation.swift (Frame position calculations) │ │ │ ├── SlidingWindow/ │ │ ├── SlidingWindowAsrManager.swift diff --git a/Sources/FluidAudio/ASR/Parakeet/Decoder/TdtDecoderV3.swift b/Sources/FluidAudio/ASR/Parakeet/Decoder/TdtDecoderV3.swift index 54ccc34fe..816f18f41 100644 --- a/Sources/FluidAudio/ASR/Parakeet/Decoder/TdtDecoderV3.swift +++ b/Sources/FluidAudio/ASR/Parakeet/Decoder/TdtDecoderV3.swift @@ -32,50 +32,15 @@ import OSLog internal struct TdtDecoderV3 { - /// Joint model decision for a single encoder/decoder step. - private struct JointDecision { - let token: Int - let probability: Float - let durationBin: Int - } - private let logger = AppLogger(category: "TDT") private let config: ASRConfig - private let predictionOptions = AsrModels.optimizedPredictionOptions() + private let modelInference = TdtModelInference() // Parakeet‑TDT‑v3: duration head has 5 bins mapping directly to frame advances init(config: ASRConfig) { self.config = config } - /// Reusable input provider that holds references to preallocated - /// encoder and decoder step tensors for the joint model. 
- private final class ReusableJointInput: NSObject, MLFeatureProvider { - let encoderStep: MLMultiArray - let decoderStep: MLMultiArray - - init(encoderStep: MLMultiArray, decoderStep: MLMultiArray) { - self.encoderStep = encoderStep - self.decoderStep = decoderStep - super.init() - } - - var featureNames: Set { - ["encoder_step", "decoder_step"] - } - - func featureValue(for featureName: String) -> MLFeatureValue? { - switch featureName { - case "encoder_step": - return MLFeatureValue(multiArray: encoderStep) - case "decoder_step": - return MLFeatureValue(multiArray: decoderStep) - default: - return nil - } - } - } - /// Execute TDT decoding and return tokens with emission timestamps /// /// This is the main entry point for the decoder. It processes encoder frames sequentially, @@ -128,40 +93,22 @@ internal struct TdtDecoderV3 { // timeIndices: Current position in encoder frames (advances by duration) // timeJump: Tracks overflow when we process beyond current chunk (for streaming) // contextFrameAdjustment: Adjusts for adaptive context overlap - var timeIndices: Int - if let prevTimeJump = decoderState.timeJump { - // Streaming continuation: timeJump represents decoder position beyond previous chunk - // For the new chunk, we need to account for: - // 1. How far the decoder advanced past the previous chunk (prevTimeJump) - // 2. 
The overlap/context between chunks (contextFrameAdjustment) - // - // If prevTimeJump > 0: decoder went past previous chunk's frames - // If contextFrameAdjustment < 0: decoder should skip frames (overlap with previous chunk) - // If contextFrameAdjustment > 0: decoder should start later (adaptive context) - // Net position = prevTimeJump + contextFrameAdjustment (add adjustment to decoder position) - - // SPECIAL CASE: When prevTimeJump = 0 and contextFrameAdjustment = 0, - // decoder finished exactly at boundary but chunk has physical overlap - // Need to skip the overlap frames to avoid re-processing - if prevTimeJump == 0 && contextFrameAdjustment == 0 { - // Skip standard overlap (2.0s = 25 frames at 0.08s per frame) - timeIndices = 25 - } else { - timeIndices = max(0, prevTimeJump + contextFrameAdjustment) - } + var timeIndices = TdtFrameNavigation.calculateInitialTimeIndices( + timeJump: decoderState.timeJump, + contextFrameAdjustment: contextFrameAdjustment + ) - } else { - // First chunk: start from beginning, accounting for any context frames that were already processed - timeIndices = contextFrameAdjustment - } - // Use the minimum of encoder sequence length and actual audio frames to avoid processing padding - let effectiveSequenceLength = min(encoderSequenceLength, actualAudioFrames) + let navigationState = TdtFrameNavigation.initializeNavigationState( + timeIndices: timeIndices, + encoderSequenceLength: encoderSequenceLength, + actualAudioFrames: actualAudioFrames + ) + let effectiveSequenceLength = navigationState.effectiveSequenceLength + var safeTimeIndices = navigationState.safeTimeIndices + let lastTimestep = navigationState.lastTimestep + var activeMask = navigationState.activeMask - // Key variables for frame navigation: - var safeTimeIndices = min(timeIndices, effectiveSequenceLength - 1) // Bounds-checked index var timeIndicesCurrentLabels = timeIndices // Frame where current token was emitted - var activeMask = timeIndices < 
effectiveSequenceLength // Start processing only if we haven't exceeded bounds - let lastTimestep = effectiveSequenceLength - 1 // Maximum valid frame index // If timeJump puts us beyond the available frames, return empty if timeIndices >= effectiveSequenceLength { @@ -183,7 +130,7 @@ internal struct TdtDecoderV3 { shape: [1, NSNumber(value: decoderHidden), 1], dataType: .float32 ) - let jointInput = ReusableJointInput(encoderStep: reusableEncoderStep, decoderStep: reusableDecoderStep) + let jointInput = ReusableJointInputProvider(encoderStep: reusableEncoderStep, decoderStep: reusableDecoderStep) // Cache frequently used stride for copying encoder frames let encDestStride = reusableEncoderStep.strides.map { $0.intValue }[1] let encDestPtr = reusableEncoderStep.dataPointer.bindMemory(to: Float.self, capacity: encoderHidden) @@ -206,7 +153,7 @@ internal struct TdtDecoderV3 { // Note: In RNN-T/TDT, we use blank token as SOS if decoderState.predictorOutput == nil && hypothesis.lastToken == nil { let sos = config.tdtConfig.blankId // blank=8192 serves as SOS - let primed = try runDecoder( + let primed = try modelInference.runDecoder( token: sos, state: decoderState, model: decoderModel, @@ -226,10 +173,12 @@ internal struct TdtDecoderV3 { var emissionsAtThisTimestamp = 0 let maxSymbolsPerStep = config.tdtConfig.maxSymbolsPerStep // Usually 5-10 var tokensProcessedThisChunk = 0 // Track tokens per chunk to prevent runaway decoding + var iterCount = 0 // ===== MAIN DECODING LOOP ===== // Process each encoder frame until we've consumed all audio while activeMask { + iterCount += 1 try Task.checkCancellation() // Use last emitted token for decoder context, or blank if starting var label = hypothesis.lastToken ?? 
config.tdtConfig.blankId @@ -247,7 +196,7 @@ internal struct TdtDecoderV3 { decoderResult = (output: provider, newState: stateToUse) } else { // No cache - run decoder LSTM - decoderResult = try runDecoder( + decoderResult = try modelInference.runDecoder( token: label, state: stateToUse, model: decoderModel, @@ -259,10 +208,10 @@ internal struct TdtDecoderV3 { // Prepare decoder projection once and reuse for inner blank loop let decoderProjection = try extractFeatureValue( from: decoderResult.output, key: "decoder", errorMessage: "Invalid decoder output") - try normalizeDecoderProjection(decoderProjection, into: reusableDecoderStep) + try modelInference.normalizeDecoderProjection(decoderProjection, into: reusableDecoderStep) // Run joint network with preallocated inputs - let decision = try runJointPrepared( + let decision = try modelInference.runJointPrepared( encoderFrames: encoderFrames, timeIndex: safeTimeIndices, preparedDecoderStep: reusableDecoderStep, @@ -278,11 +227,11 @@ internal struct TdtDecoderV3 { // Predict token (what to emit) and duration (how many frames to skip) label = decision.token - var score = clampProbability(decision.probability) + var score = TdtDurationMapping.clampProbability(decision.probability) // Map duration bin to actual frame count // durationBins typically = [0,1,2,3,4] meaning skip 0-4 frames - var duration = try mapDurationBin( + var duration = try TdtDurationMapping.mapDurationBin( decision.durationBin, durationBins: config.tdtConfig.durationBins) let blankId = config.tdtConfig.blankId // 8192 for v3 models @@ -329,12 +278,14 @@ internal struct TdtDecoderV3 { // - Avoids expensive LSTM computations for silence frames // - Maintains linguistic continuity across gaps in speech // - Speeds up processing by 2-3x for audio with silence + var innerLoopCount = 0 while advanceMask { + innerLoopCount += 1 try Task.checkCancellation() timeIndicesCurrentLabels = timeIndices // INTENTIONAL: Reusing prepared decoder step from outside loop 
- let innerDecision = try runJointPrepared( + let innerDecision = try modelInference.runJointPrepared( encoderFrames: encoderFrames, timeIndex: safeTimeIndices, preparedDecoderStep: reusableDecoderStep, @@ -349,8 +300,8 @@ internal struct TdtDecoderV3 { ) label = innerDecision.token - score = clampProbability(innerDecision.probability) - duration = try mapDurationBin( + score = TdtDurationMapping.clampProbability(innerDecision.probability) + duration = try TdtDurationMapping.mapDurationBin( innerDecision.durationBin, durationBins: config.tdtConfig.durationBins) blankMask = (label == blankId) @@ -360,7 +311,8 @@ internal struct TdtDecoderV3 { duration = 1 } - // Advance and check if we should continue the inner loop + // Advance by duration regardless of blank/non-blank + // This is the ORIGINAL and CORRECT logic timeIndices += duration safeTimeIndices = min(timeIndices, lastTimestep) activeMask = timeIndices < effectiveSequenceLength @@ -389,7 +341,7 @@ internal struct TdtDecoderV3 { // Only non-blank tokens update the decoder - this is key! 
// NOTE: We update the decoder state regardless of whether we emit the token // to maintain proper language model context across chunk boundaries - let step = try runDecoder( + let step = try modelInference.runDecoder( token: label, state: decoderResult.newState, model: decoderModel, @@ -447,7 +399,7 @@ internal struct TdtDecoderV3 { ]) decoderResult = (output: provider, newState: stateToUse) } else { - decoderResult = try runDecoder( + decoderResult = try modelInference.runDecoder( token: lastToken, state: stateToUse, model: decoderModel, @@ -467,9 +419,9 @@ internal struct TdtDecoderV3 { // Prepare decoder projection into reusable buffer (if not already) let finalProjection = try extractFeatureValue( from: decoderResult.output, key: "decoder", errorMessage: "Invalid decoder output") - try normalizeDecoderProjection(finalProjection, into: reusableDecoderStep) + try modelInference.normalizeDecoderProjection(finalProjection, into: reusableDecoderStep) - let decision = try runJointPrepared( + let decision = try modelInference.runJointPrepared( encoderFrames: encoderFrames, timeIndex: frameIndex, preparedDecoderStep: reusableDecoderStep, @@ -484,10 +436,10 @@ internal struct TdtDecoderV3 { ) let token = decision.token - let score = clampProbability(decision.probability) + let score = TdtDurationMapping.clampProbability(decision.probability) // Also get duration for proper timestamp calculation - let duration = try mapDurationBin( + let duration = try TdtDurationMapping.mapDurationBin( decision.durationBin, durationBins: config.tdtConfig.durationBins) if token == config.tdtConfig.blankId { @@ -507,7 +459,7 @@ internal struct TdtDecoderV3 { hypothesis.lastToken = token // Update decoder state - let step = try runDecoder( + let step = try modelInference.runDecoder( token: token, state: decoderResult.newState, model: decoderModel, @@ -536,213 +488,24 @@ internal struct TdtDecoderV3 { // Clear cached predictor output if ending with punctuation // This prevents punctuation 
from being duplicated at chunk boundaries - if let lastToken = hypothesis.lastToken { - let punctuationTokens = [7883, 7952, 7948] // period, question, exclamation - if punctuationTokens.contains(lastToken) { - decoderState.predictorOutput = nil - // Keep lastToken for linguistic context - deduplication handles duplicates at higher level - } + if let lastToken = hypothesis.lastToken, + ASRConstants.punctuationTokens.contains(lastToken) + { + decoderState.predictorOutput = nil + // Keep lastToken for linguistic context - deduplication handles duplicates at higher level } - // Always store time jump for streaming: how far beyond this chunk we've processed - // Used to align timestamps when processing next chunk - // Formula: timeJump = finalPosition - effectiveFrames - let finalTimeJump = timeIndices - effectiveSequenceLength - decoderState.timeJump = finalTimeJump - - // For the last chunk, clear timeJump since there are no more chunks - if isLastChunk { - decoderState.timeJump = nil - } + // Calculate final timeJump for streaming continuation + decoderState.timeJump = TdtFrameNavigation.calculateFinalTimeJump( + currentTimeIndices: timeIndices, + effectiveSequenceLength: effectiveSequenceLength, + isLastChunk: isLastChunk + ) // No filtering at decoder level - let post-processing handle deduplication return hypothesis } - /// Decoder execution - private func runDecoder( - token: Int, - state: TdtDecoderState, - model: MLModel, - targetArray: MLMultiArray, - targetLengthArray: MLMultiArray - ) throws -> (output: MLFeatureProvider, newState: TdtDecoderState) { - - // Reuse pre-allocated arrays - targetArray[0] = NSNumber(value: token) - // targetLengthArray[0] is already set to 1 and never changes - - let input = try MLDictionaryFeatureProvider(dictionary: [ - "targets": MLFeatureValue(multiArray: targetArray), - "target_length": MLFeatureValue(multiArray: targetLengthArray), - "h_in": MLFeatureValue(multiArray: state.hiddenState), - "c_in": 
MLFeatureValue(multiArray: state.cellState), - ]) - - // Reuse decoder state output buffers to avoid CoreML allocating new ones - // Note: outputBackings expects raw backing objects (MLMultiArray / CVPixelBuffer) - predictionOptions.outputBackings = [ - "h_out": state.hiddenState, - "c_out": state.cellState, - ] - - let output = try model.prediction( - from: input, - options: predictionOptions - ) - - var newState = state - newState.update(from: output) - - return (output, newState) - } - - /// Joint network execution with zero-copy - /// Joint network execution using preallocated input arrays and a reusable provider. - private func runJointPrepared( - encoderFrames: EncoderFrameView, - timeIndex: Int, - preparedDecoderStep: MLMultiArray, - model: MLModel, - encoderStep: MLMultiArray, - encoderDestPtr: UnsafeMutablePointer, - encoderDestStride: Int, - inputProvider: MLFeatureProvider, - tokenIdBacking: MLMultiArray, - tokenProbBacking: MLMultiArray, - durationBacking: MLMultiArray - ) throws -> JointDecision { - - // Fill encoder step with the requested frame - try encoderFrames.copyFrame(at: timeIndex, into: encoderDestPtr, destinationStride: encoderDestStride) - - // Prefetch arrays for ANE - encoderStep.prefetchToNeuralEngine() - preparedDecoderStep.prefetchToNeuralEngine() - - // Reuse tiny output tensors for joint prediction (provide raw MLMultiArray backings) - predictionOptions.outputBackings = [ - "token_id": tokenIdBacking, - "token_prob": tokenProbBacking, - "duration": durationBacking, - ] - - // Execute joint network using the reusable provider - let output = try model.prediction( - from: inputProvider, - options: predictionOptions - ) - - let tokenIdArray = try extractFeatureValue( - from: output, key: "token_id", errorMessage: "Joint decision output missing token_id") - let tokenProbArray = try extractFeatureValue( - from: output, key: "token_prob", errorMessage: "Joint decision output missing token_prob") - let durationArray = try 
extractFeatureValue( - from: output, key: "duration", errorMessage: "Joint decision output missing duration") - - guard tokenIdArray.count == 1, - tokenProbArray.count == 1, - durationArray.count == 1 - else { - throw ASRError.processingFailed("Joint decision returned unexpected tensor shapes") - } - - let tokenPointer = tokenIdArray.dataPointer.bindMemory(to: Int32.self, capacity: tokenIdArray.count) - let token = Int(tokenPointer[0]) - let probPointer = tokenProbArray.dataPointer.bindMemory(to: Float.self, capacity: tokenProbArray.count) - let probability = probPointer[0] - let durationPointer = durationArray.dataPointer.bindMemory(to: Int32.self, capacity: durationArray.count) - let durationBin = Int(durationPointer[0]) - - return JointDecision(token: token, probability: probability, durationBin: durationBin) - } - - /// Normalize decoder projection into [1, hiddenSize, 1] layout via BLAS copy. - /// If `destination` is provided, writes into it (hot path). Otherwise allocates a new array. - @discardableResult - private func normalizeDecoderProjection( - _ projection: MLMultiArray, - into destination: MLMultiArray? = nil - ) throws -> MLMultiArray { - let hiddenSize = ASRConstants.decoderHiddenSize - let shape = projection.shape.map { $0.intValue } - - guard shape.count == 3 else { - throw ASRError.processingFailed("Invalid decoder projection rank: \(shape)") - } - guard shape[0] == 1 else { - throw ASRError.processingFailed("Unsupported decoder batch dimension: \(shape[0])") - } - guard projection.dataType == .float32 else { - throw ASRError.processingFailed("Unsupported decoder projection type: \(projection.dataType)") - } - - let hiddenAxis: Int - if shape[2] == hiddenSize { - hiddenAxis = 2 - } else if shape[1] == hiddenSize { - hiddenAxis = 1 - } else { - throw ASRError.processingFailed("Decoder projection hidden size mismatch: \(shape)") - } - - let timeAxis = (0...2).first { $0 != hiddenAxis && $0 != 0 } ?? 
1 - guard shape[timeAxis] == 1 else { - throw ASRError.processingFailed("Decoder projection time axis must be 1: \(shape)") - } - - let out: MLMultiArray - if let destination { - let outShape = destination.shape.map { $0.intValue } - guard destination.dataType == .float32, outShape.count == 3, outShape[0] == 1, - outShape[2] == 1, outShape[1] == hiddenSize - else { - throw ASRError.processingFailed( - "Prepared decoder step shape mismatch: \(destination.shapeString)") - } - out = destination - } else { - out = try ANEMemoryUtils.createAlignedArray( - shape: [1, NSNumber(value: hiddenSize), 1], - dataType: .float32 - ) - } - - let destPtr = out.dataPointer.bindMemory(to: Float.self, capacity: hiddenSize) - let destStrides = out.strides.map { $0.intValue } - let destHiddenStride = destStrides[1] - let destStrideCblas = try makeBlasIndex(destHiddenStride, label: "Decoder destination stride") - - let sourcePtr = projection.dataPointer.bindMemory(to: Float.self, capacity: projection.count) - let strides = projection.strides.map { $0.intValue } - let hiddenStride = strides[hiddenAxis] - let timeStride = strides[timeAxis] - let batchStride = strides[0] - - var baseOffset = 0 - if batchStride < 0 { baseOffset += (shape[0] - 1) * batchStride } - if timeStride < 0 { baseOffset += (shape[timeAxis] - 1) * timeStride } - - let minOffset = hiddenStride < 0 ? hiddenStride * (hiddenSize - 1) : 0 - let maxOffset = hiddenStride > 0 ? 
hiddenStride * (hiddenSize - 1) : 0 - let lowerBound = baseOffset + minOffset - let upperBound = baseOffset + maxOffset - guard lowerBound >= 0 && upperBound < projection.count else { - throw ASRError.processingFailed("Decoder projection stride exceeds buffer bounds") - } - - let startPtr = sourcePtr.advanced(by: baseOffset) - if hiddenStride == 1 && destHiddenStride == 1 { - destPtr.update(from: startPtr, count: hiddenSize) - } else { - let count = try makeBlasIndex(hiddenSize, label: "Decoder projection length") - let stride = try makeBlasIndex(hiddenStride, label: "Decoder projection stride") - cblas_scopy(count, startPtr, stride, destPtr, destStrideCblas) - } - - return out - } - /// Update hypothesis with new token internal func updateHypothesis( _ hypothesis: inout TdtHypothesis, @@ -763,17 +526,6 @@ internal struct TdtDecoderV3 { } // MARK: - Private Helper Methods - private func mapDurationBin(_ binIndex: Int, durationBins: [Int]) throws -> Int { - guard binIndex >= 0 && binIndex < durationBins.count else { - throw ASRError.processingFailed("Duration bin index out of range: \(binIndex)") - } - return durationBins[binIndex] - } - - private func clampProbability(_ value: Float) -> Float { - guard value.isFinite else { return 0 } - return min(max(value, 0), 1) - } internal func extractEncoderTimeStep( _ encoderOutput: MLMultiArray, timeIndex: Int @@ -838,7 +590,7 @@ internal struct TdtDecoderV3 { let decoderProjection = try extractFeatureValue( from: decoderOutput, key: "decoder", errorMessage: "Invalid decoder output") - let normalizedDecoder = try normalizeDecoderProjection(decoderProjection) + let normalizedDecoder = try modelInference.normalizeDecoderProjection(decoderProjection) return try MLDictionaryFeatureProvider(dictionary: [ "encoder_step": MLFeatureValue(multiArray: encoderStep), diff --git a/Sources/FluidAudio/ASR/Parakeet/Decoder/TdtDurationMapping.swift b/Sources/FluidAudio/ASR/Parakeet/Decoder/TdtDurationMapping.swift new file mode 100644 index 
000000000..89470e8fe --- /dev/null +++ b/Sources/FluidAudio/ASR/Parakeet/Decoder/TdtDurationMapping.swift @@ -0,0 +1,32 @@ +import Foundation + +/// Utilities for mapping TDT duration bins to encoder frame advances. +enum TdtDurationMapping { + + /// Map a duration bin index to actual encoder frames to advance. + /// + /// Parakeet-TDT models use a discrete duration head with bins that map to frame advances. + /// - v3 models: 5 bins [0, 1, 2, 3, 4] (skip 0-4 frames) + /// - v2 models: May have different bin configurations + /// + /// - Parameters: + /// - binIndex: The duration bin index from the model output + /// - durationBins: Array mapping bin indices to frame advances + /// - Returns: Number of encoder frames to advance + /// - Throws: `ASRError.processingFailed` if binIndex is out of range + static func mapDurationBin(_ binIndex: Int, durationBins: [Int]) throws -> Int { + guard binIndex >= 0 && binIndex < durationBins.count else { + throw ASRError.processingFailed("Duration bin index out of range: \(binIndex)") + } + return durationBins[binIndex] + } + + /// Clamp probability to valid range [0, 1] to handle edge cases. + /// + /// - Parameter value: Raw probability value (may be slightly outside [0,1] due to float precision or NaN) + /// - Returns: Clamped probability in [0, 1], or 0 if value is not finite + static func clampProbability(_ value: Float) -> Float { + guard value.isFinite else { return 0 } + return max(0.0, min(1.0, value)) + } +} diff --git a/Sources/FluidAudio/ASR/Parakeet/Decoder/TdtFrameNavigation.swift b/Sources/FluidAudio/ASR/Parakeet/Decoder/TdtFrameNavigation.swift new file mode 100644 index 000000000..01d355599 --- /dev/null +++ b/Sources/FluidAudio/ASR/Parakeet/Decoder/TdtFrameNavigation.swift @@ -0,0 +1,106 @@ +import Foundation + +/// Frame navigation utilities for TDT decoding.
+/// +/// Handles time index calculations for streaming ASR with chunk-based processing, +/// including timeJump management for decoder position tracking across chunks. +internal struct TdtFrameNavigation { + + /// Calculate initial time indices for chunk processing. + /// + /// Determines where to start processing in the current chunk based on: + /// - Previous timeJump (how far past the previous chunk the decoder advanced) + /// - Context frame adjustment (adaptive overlap compensation) + /// + /// - Parameters: + /// - timeJump: Optional timeJump from previous chunk (nil for first chunk) + /// - contextFrameAdjustment: Frame offset for adaptive context + /// + /// - Returns: Starting frame index for this chunk + static func calculateInitialTimeIndices( + timeJump: Int?, + contextFrameAdjustment: Int + ) -> Int { + // First chunk: start from beginning, accounting for any context frames already processed + guard let prevTimeJump = timeJump else { + return contextFrameAdjustment + } + + // Streaming continuation: timeJump represents decoder position beyond previous chunk + // For the new chunk, we need to account for: + // 1. How far the decoder advanced past the previous chunk (prevTimeJump) + // 2. 
The overlap/context between chunks (contextFrameAdjustment) + // + // If prevTimeJump > 0: decoder went past previous chunk's frames + // If contextFrameAdjustment < 0: decoder should skip frames (overlap with previous chunk) + // If contextFrameAdjustment > 0: decoder should start later (adaptive context) + // Net position = prevTimeJump + contextFrameAdjustment (add adjustment to decoder position) + + // SPECIAL CASE: When prevTimeJump = 0 and contextFrameAdjustment = 0, + // decoder finished exactly at boundary but chunk has physical overlap + // Need to skip the overlap frames to avoid re-processing + if prevTimeJump == 0 && contextFrameAdjustment == 0 { + // Skip standard overlap (2.0s = 25 frames at 0.08s per frame) + return ASRConstants.standardOverlapFrames + } + + // Normal streaming continuation + return max(0, prevTimeJump + contextFrameAdjustment) + } + + /// Initialize frame navigation state for decoding loop. + /// + /// - Parameters: + /// - timeIndices: Initial time index calculated from timeJump + /// - encoderSequenceLength: Total encoder frames in this chunk + /// - actualAudioFrames: Actual audio frames (excluding padding) + /// + /// - Returns: Tuple of navigation state values + static func initializeNavigationState( + timeIndices: Int, + encoderSequenceLength: Int, + actualAudioFrames: Int + ) -> ( + effectiveSequenceLength: Int, + safeTimeIndices: Int, + lastTimestep: Int, + activeMask: Bool + ) { + // Use the minimum of encoder sequence length and actual audio frames to avoid processing padding + let effectiveSequenceLength = min(encoderSequenceLength, actualAudioFrames) + + // Key variables for frame navigation: + let safeTimeIndices = min(timeIndices, effectiveSequenceLength - 1) // Bounds-checked index + let lastTimestep = effectiveSequenceLength - 1 // Maximum valid frame index + let activeMask = timeIndices < effectiveSequenceLength // Start processing only if we haven't exceeded bounds + + return (effectiveSequenceLength, 
safeTimeIndices, lastTimestep, activeMask) + } + + /// Calculate final timeJump for streaming continuation. + /// + /// TimeJump tracks how far beyond the current chunk the decoder has advanced, + /// which is used to properly position the decoder in the next chunk. + /// + /// - Parameters: + /// - currentTimeIndices: Final time index after processing + /// - effectiveSequenceLength: Number of valid frames in this chunk + /// - isLastChunk: Whether this is the last chunk (no more chunks to process) + /// + /// - Returns: TimeJump value (nil for last chunk, otherwise offset from chunk boundary) + static func calculateFinalTimeJump( + currentTimeIndices: Int, + effectiveSequenceLength: Int, + isLastChunk: Bool + ) -> Int? { + // For the last chunk, clear timeJump since there are no more chunks + if isLastChunk { + return nil + } + + // Always store time jump for streaming: how far beyond this chunk we've processed + // Used to align timestamps when processing next chunk + // Formula: timeJump = finalPosition - effectiveFrames + return currentTimeIndices - effectiveSequenceLength + } +} diff --git a/Sources/FluidAudio/ASR/Parakeet/Decoder/TdtJointDecision.swift b/Sources/FluidAudio/ASR/Parakeet/Decoder/TdtJointDecision.swift new file mode 100644 index 000000000..d6f412433 --- /dev/null +++ b/Sources/FluidAudio/ASR/Parakeet/Decoder/TdtJointDecision.swift @@ -0,0 +1,14 @@ +/// Joint model decision for a single encoder/decoder step. +/// +/// Represents the output of the TDT joint network which combines encoder and decoder features +/// to predict the next token, its probability, and how many audio frames to skip. 
+internal struct TdtJointDecision { + /// Predicted token ID from vocabulary + let token: Int + + /// Softmax probability for this token + let probability: Float + + /// Duration bin index (maps to number of encoder frames to skip) + let durationBin: Int +} diff --git a/Sources/FluidAudio/ASR/Parakeet/Decoder/TdtJointInputProvider.swift b/Sources/FluidAudio/ASR/Parakeet/Decoder/TdtJointInputProvider.swift new file mode 100644 index 000000000..e90e45ecc --- /dev/null +++ b/Sources/FluidAudio/ASR/Parakeet/Decoder/TdtJointInputProvider.swift @@ -0,0 +1,50 @@ +import CoreML +import Foundation + +/// Reusable input provider for TDT joint model inference. +/// +/// This class holds pre-allocated MLMultiArray tensors for encoder and decoder features, +/// allowing zero-copy joint network execution. By reusing the same arrays across +/// inference calls, we avoid repeated allocations and improve ANE performance. +/// +/// Usage: +/// ```swift +/// let provider = ReusableJointInputProvider( +/// encoderStep: encoderStepArray, // Shape: [1, 1024, 1] +/// decoderStep: decoderStepArray // Shape: [1, 640, 1] +/// ) +/// let output = try jointModel.prediction(from: provider) +/// ``` +internal final class ReusableJointInputProvider: NSObject, MLFeatureProvider { + /// Encoder feature tensor (presumably [1, encoder_dim, 1] — confirm against allocation in TdtDecoderV3) + let encoderStep: MLMultiArray + + /// Decoder feature tensor (shape: [1, decoder_dim, 1], matching the decoder's prepared step buffer) + let decoderStep: MLMultiArray + + /// Initialize with pre-allocated encoder and decoder step tensors. + /// + /// - Parameters: + /// - encoderStep: MLMultiArray for encoder features (typically [1, 1024, 1]) + /// - decoderStep: MLMultiArray for decoder features (typically [1, 640, 1]) + init(encoderStep: MLMultiArray, decoderStep: MLMultiArray) { + self.encoderStep = encoderStep + self.decoderStep = decoderStep + super.init() + } + + var featureNames: Set { + ["encoder_step", "decoder_step"] + } + + func featureValue(for featureName: String) -> MLFeatureValue? 
{ + switch featureName { + case "encoder_step": + return MLFeatureValue(multiArray: encoderStep) + case "decoder_step": + return MLFeatureValue(multiArray: decoderStep) + default: + return nil + } + } +} diff --git a/Sources/FluidAudio/ASR/Parakeet/Decoder/TdtModelInference.swift b/Sources/FluidAudio/ASR/Parakeet/Decoder/TdtModelInference.swift new file mode 100644 index 000000000..3bf609536 --- /dev/null +++ b/Sources/FluidAudio/ASR/Parakeet/Decoder/TdtModelInference.swift @@ -0,0 +1,226 @@ +import Accelerate +import CoreML +import Foundation + +/// Model inference operations for TDT decoding. +/// +/// Encapsulates execution of decoder LSTM, joint network, and decoder projection normalization. +/// These operations are separated from the main decoding loop to improve testability and clarity. +internal struct TdtModelInference { + private let predictionOptions: MLPredictionOptions + + init() { + self.predictionOptions = AsrModels.optimizedPredictionOptions() + } + + /// Execute decoder LSTM with state caching. 
+ /// + /// - Parameters: + /// - token: Token ID to decode + /// - state: Current decoder LSTM state + /// - model: Decoder MLModel + /// - targetArray: Pre-allocated array for token input + /// - targetLengthArray: Pre-allocated array for length (always 1) + /// + /// - Returns: Tuple of (output features, updated state) + func runDecoder( + token: Int, + state: TdtDecoderState, + model: MLModel, + targetArray: MLMultiArray, + targetLengthArray: MLMultiArray + ) throws -> (output: MLFeatureProvider, newState: TdtDecoderState) { + + // Reuse pre-allocated arrays + targetArray[0] = NSNumber(value: token) + // targetLengthArray[0] is already set to 1 and never changes + + let input = try MLDictionaryFeatureProvider(dictionary: [ + "targets": MLFeatureValue(multiArray: targetArray), + "target_length": MLFeatureValue(multiArray: targetLengthArray), + "h_in": MLFeatureValue(multiArray: state.hiddenState), + "c_in": MLFeatureValue(multiArray: state.cellState), + ]) + + // Reuse decoder state output buffers to avoid CoreML allocating new ones + // Note: outputBackings expects raw backing objects (MLMultiArray / CVPixelBuffer) + predictionOptions.outputBackings = [ + "h_out": state.hiddenState, + "c_out": state.cellState, + ] + + let output = try model.prediction( + from: input, + options: predictionOptions + ) + + var newState = state + newState.update(from: output) + + return (output, newState) + } + + /// Execute joint network with zero-copy and ANE optimization. 
+ /// + /// - Parameters: + /// - encoderFrames: View into encoder output tensor + /// - timeIndex: Frame index to process + /// - preparedDecoderStep: Normalized decoder projection + /// - model: Joint MLModel + /// - encoderStep: Pre-allocated encoder step array + /// - encoderDestPtr: Pointer for encoder frame copy + /// - encoderDestStride: Stride for encoder copy + /// - inputProvider: Reusable feature provider + /// - tokenIdBacking: Pre-allocated output for token ID + /// - tokenProbBacking: Pre-allocated output for probability + /// - durationBacking: Pre-allocated output for duration + /// + /// - Returns: Joint decision (token, probability, duration bin) + func runJointPrepared( + encoderFrames: EncoderFrameView, + timeIndex: Int, + preparedDecoderStep: MLMultiArray, + model: MLModel, + encoderStep: MLMultiArray, + encoderDestPtr: UnsafeMutablePointer, + encoderDestStride: Int, + inputProvider: MLFeatureProvider, + tokenIdBacking: MLMultiArray, + tokenProbBacking: MLMultiArray, + durationBacking: MLMultiArray + ) throws -> TdtJointDecision { + + // Fill encoder step with the requested frame + try encoderFrames.copyFrame(at: timeIndex, into: encoderDestPtr, destinationStride: encoderDestStride) + + // Prefetch arrays for ANE + encoderStep.prefetchToNeuralEngine() + preparedDecoderStep.prefetchToNeuralEngine() + + // Reuse tiny output tensors for joint prediction (provide raw MLMultiArray backings) + predictionOptions.outputBackings = [ + "token_id": tokenIdBacking, + "token_prob": tokenProbBacking, + "duration": durationBacking, + ] + + // Execute joint network using the reusable provider + let output = try model.prediction( + from: inputProvider, + options: predictionOptions + ) + + let tokenIdArray = try extractFeatureValue( + from: output, key: "token_id", errorMessage: "Joint decision output missing token_id") + let tokenProbArray = try extractFeatureValue( + from: output, key: "token_prob", errorMessage: "Joint decision output missing token_prob") + 
let durationArray = try extractFeatureValue( + from: output, key: "duration", errorMessage: "Joint decision output missing duration") + + guard tokenIdArray.count == 1, + tokenProbArray.count == 1, + durationArray.count == 1 + else { + throw ASRError.processingFailed("Joint decision returned unexpected tensor shapes") + } + + let tokenPointer = tokenIdArray.dataPointer.bindMemory(to: Int32.self, capacity: tokenIdArray.count) + let token = Int(tokenPointer[0]) + let probPointer = tokenProbArray.dataPointer.bindMemory(to: Float.self, capacity: tokenProbArray.count) + let probability = probPointer[0] + let durationPointer = durationArray.dataPointer.bindMemory(to: Int32.self, capacity: durationArray.count) + let durationBin = Int(durationPointer[0]) + + return TdtJointDecision(token: token, probability: probability, durationBin: durationBin) + } + + /// Normalize decoder projection into [1, hiddenSize, 1] layout via BLAS copy. + /// + /// CoreML decoder outputs can have varying layouts ([1, 1, 640] or [1, 640, 1]). + /// This function normalizes to the joint network's expected input format using + /// efficient BLAS operations to handle arbitrary strides. + /// + /// - Parameters: + /// - projection: Decoder output projection (any 3D layout with hiddenSize dimension) + /// - destination: Optional pre-allocated destination array (for hot path) + /// + /// - Returns: Normalized array in [1, hiddenSize, 1] format + @discardableResult + func normalizeDecoderProjection( + _ projection: MLMultiArray, + into destination: MLMultiArray? 
= nil + ) throws -> MLMultiArray { + let hiddenSize = ASRConstants.decoderHiddenSize + let shape = projection.shape.map { $0.intValue } + + guard shape.count == 3 else { + throw ASRError.processingFailed("Invalid decoder projection rank: \(shape)") + } + guard shape[0] == 1 else { + throw ASRError.processingFailed("Unsupported decoder batch dimension: \(shape[0])") + } + guard projection.dataType == .float32 else { + throw ASRError.processingFailed("Unsupported decoder projection type: \(projection.dataType)") + } + + let hiddenAxis: Int + if shape[2] == hiddenSize { + hiddenAxis = 2 + } else if shape[1] == hiddenSize { + hiddenAxis = 1 + } else { + throw ASRError.processingFailed("Decoder projection hidden size mismatch: \(shape)") + } + + let timeAxis = (0...2).first { $0 != hiddenAxis && $0 != 0 } ?? 1 + guard shape[timeAxis] == 1 else { + throw ASRError.processingFailed("Decoder projection time axis must be 1: \(shape)") + } + + let out: MLMultiArray + if let destination { + let outShape = destination.shape.map { $0.intValue } + guard destination.dataType == .float32, outShape.count == 3, outShape[0] == 1, + outShape[2] == 1, outShape[1] == hiddenSize + else { + throw ASRError.processingFailed( + "Prepared decoder step shape mismatch: \(destination.shapeString)") + } + out = destination + } else { + out = try ANEMemoryUtils.createAlignedArray( + shape: [1, NSNumber(value: hiddenSize), 1], + dataType: .float32 + ) + } + + let strides = projection.strides.map { $0.intValue } + let hiddenStride = strides[hiddenAxis] + + let dataPointer = projection.dataPointer.bindMemory(to: Float.self, capacity: projection.count) + let startPtr = dataPointer.advanced(by: 0) + + let destPtr = out.dataPointer.bindMemory(to: Float.self, capacity: hiddenSize) + let destStrides = out.strides.map { $0.intValue } + let destHiddenStride = destStrides[1] + let destStrideCblas = try makeBlasIndex(destHiddenStride, label: "Decoder destination stride") + + let count = try 
makeBlasIndex(hiddenSize, label: "Decoder projection length") + let stride = try makeBlasIndex(hiddenStride, label: "Decoder projection stride") + cblas_scopy(count, startPtr, stride, destPtr, destStrideCblas) + + return out + } + + /// Extract MLMultiArray feature value with error handling. + private func extractFeatureValue( + from output: MLFeatureProvider, key: String, errorMessage: String + ) throws + -> MLMultiArray + { + guard let value = output.featureValue(for: key)?.multiArrayValue else { + throw ASRError.processingFailed(errorMessage) + } + return value + } +} diff --git a/Sources/FluidAudio/Shared/ASRConstants.swift b/Sources/FluidAudio/Shared/ASRConstants.swift index 5a78de668..68c3b17e4 100644 --- a/Sources/FluidAudio/Shared/ASRConstants.swift +++ b/Sources/FluidAudio/Shared/ASRConstants.swift @@ -33,6 +33,18 @@ public enum ASRConstants { /// WER threshold for detailed error analysis in benchmarks public static let highWERThreshold: Double = 0.15 + /// Punctuation token IDs (period, question mark, exclamation mark) + public static let punctuationTokens: [Int] = [7883, 7952, 7948] + + /// Standard overlap in encoder frames (2.0s = 25 frames at 0.08s per frame) + public static let standardOverlapFrames: Int = 25 + + /// Minimum confidence score (for empty or very uncertain transcriptions) + public static let minConfidence: Float = 0.1 + + /// Maximum confidence score (perfect confidence) + public static let maxConfidence: Float = 1.0 + /// Calculate encoder frames from audio samples using proper ceiling division /// - Parameter samples: Number of audio samples /// - Returns: Number of encoder frames diff --git a/Tests/FluidAudioTests/ASR/Parakeet/Decoder/TdtRefactoredComponentsTests.swift b/Tests/FluidAudioTests/ASR/Parakeet/Decoder/TdtRefactoredComponentsTests.swift new file mode 100644 index 000000000..21b9387db --- /dev/null +++ b/Tests/FluidAudioTests/ASR/Parakeet/Decoder/TdtRefactoredComponentsTests.swift @@ -0,0 +1,380 @@ +import CoreML +import 
Foundation +import XCTest + +@testable import FluidAudio + +/// Tests for refactored TDT decoder components. +final class TdtRefactoredComponentsTests: XCTestCase { + + // MARK: - TdtFrameNavigation Tests + + func testCalculateInitialTimeIndicesFirstChunk() { + // First chunk with no timeJump + let result = TdtFrameNavigation.calculateInitialTimeIndices( + timeJump: nil, + contextFrameAdjustment: 0 + ) + XCTAssertEqual(result, 0, "First chunk should start at 0") + } + + func testCalculateInitialTimeIndicesFirstChunkWithContext() { + // First chunk with context adjustment + let result = TdtFrameNavigation.calculateInitialTimeIndices( + timeJump: nil, + contextFrameAdjustment: 5 + ) + XCTAssertEqual(result, 5, "First chunk should start at context adjustment") + } + + func testCalculateInitialTimeIndicesStreamingContinuation() { + // Normal streaming continuation + let result = TdtFrameNavigation.calculateInitialTimeIndices( + timeJump: 10, + contextFrameAdjustment: -5 + ) + XCTAssertEqual(result, 5, "Should sum timeJump and context adjustment") + } + + func testCalculateInitialTimeIndicesSpecialOverlapCase() { + // Special case: decoder finished exactly at boundary with overlap + let result = TdtFrameNavigation.calculateInitialTimeIndices( + timeJump: 0, + contextFrameAdjustment: 0 + ) + XCTAssertEqual( + result, + ASRConstants.standardOverlapFrames, + "Should skip standard overlap frames" + ) + } + + func testCalculateInitialTimeIndicesNegativeResult() { + // Should clamp negative results to 0 + let result = TdtFrameNavigation.calculateInitialTimeIndices( + timeJump: -10, + contextFrameAdjustment: 5 + ) + XCTAssertEqual(result, 0, "Should clamp negative results to 0") + } + + func testInitializeNavigationState() { + let (effectiveLength, safeIndices, lastTimestep, activeMask) = + TdtFrameNavigation.initializeNavigationState( + timeIndices: 10, + encoderSequenceLength: 100, + actualAudioFrames: 80 + ) + + XCTAssertEqual(effectiveLength, 80, "Should use minimum of 
encoder and audio frames") + XCTAssertEqual(safeIndices, 10, "Safe indices should be clamped to valid range") + XCTAssertEqual(lastTimestep, 79, "Last timestep is effectiveLength - 1") + XCTAssertTrue(activeMask, "Active mask should be true when timeIndices < effectiveLength") + } + + func testInitializeNavigationStateOutOfBounds() { + let (_, safeIndices, _, activeMask) = TdtFrameNavigation.initializeNavigationState( + timeIndices: 100, + encoderSequenceLength: 80, + actualAudioFrames: 80 + ) + + XCTAssertEqual(safeIndices, 79, "Should clamp to effectiveLength - 1") + XCTAssertFalse(activeMask, "Active mask should be false when timeIndices >= effectiveLength") + } + + func testCalculateFinalTimeJumpLastChunk() { + let result = TdtFrameNavigation.calculateFinalTimeJump( + currentTimeIndices: 100, + effectiveSequenceLength: 80, + isLastChunk: true + ) + XCTAssertNil(result, "Last chunk should return nil") + } + + func testCalculateFinalTimeJumpStreamingChunk() { + let result = TdtFrameNavigation.calculateFinalTimeJump( + currentTimeIndices: 100, + effectiveSequenceLength: 80, + isLastChunk: false + ) + XCTAssertEqual(result, 20, "Should return offset from chunk boundary") + } + + func testCalculateFinalTimeJumpNegativeOffset() { + let result = TdtFrameNavigation.calculateFinalTimeJump( + currentTimeIndices: 50, + effectiveSequenceLength: 80, + isLastChunk: false + ) + XCTAssertEqual(result, -30, "Should handle negative offsets") + } + + // MARK: - TdtDurationMapping Tests + + func testMapDurationBinValidIndices() throws { + let v3Bins = [1, 2, 3, 4, 5] + XCTAssertEqual(try TdtDurationMapping.mapDurationBin(0, durationBins: v3Bins), 1) + XCTAssertEqual(try TdtDurationMapping.mapDurationBin(1, durationBins: v3Bins), 2) + XCTAssertEqual(try TdtDurationMapping.mapDurationBin(2, durationBins: v3Bins), 3) + XCTAssertEqual(try TdtDurationMapping.mapDurationBin(3, durationBins: v3Bins), 4) + XCTAssertEqual(try TdtDurationMapping.mapDurationBin(4, durationBins: v3Bins), 5) + 
} + + func testMapDurationBinCustomMapping() throws { + let customBins = [1, 1, 2, 3, 5, 8] + XCTAssertEqual(try TdtDurationMapping.mapDurationBin(0, durationBins: customBins), 1) + XCTAssertEqual(try TdtDurationMapping.mapDurationBin(3, durationBins: customBins), 3) + XCTAssertEqual(try TdtDurationMapping.mapDurationBin(5, durationBins: customBins), 8) + } + + func testMapDurationBinOutOfRange() { + let v3Bins = [1, 2, 3, 4, 5] + XCTAssertThrowsError(try TdtDurationMapping.mapDurationBin(5, durationBins: v3Bins)) { error in + guard let asrError = error as? ASRError else { + XCTFail("Expected ASRError") + return + } + if case .processingFailed(let message) = asrError { + XCTAssertTrue(message.contains("Duration bin index out of range")) + } else { + XCTFail("Expected processingFailed error") + } + } + } + + func testMapDurationBinNegativeIndex() { + let v3Bins = [1, 2, 3, 4, 5] + XCTAssertThrowsError(try TdtDurationMapping.mapDurationBin(-1, durationBins: v3Bins)) + } + + func testClampProbabilityValidRange() { + XCTAssertEqual(TdtDurationMapping.clampProbability(0.5), 0.5, accuracy: 0.0001) + XCTAssertEqual(TdtDurationMapping.clampProbability(0.0), 0.0, accuracy: 0.0001) + XCTAssertEqual(TdtDurationMapping.clampProbability(1.0), 1.0, accuracy: 0.0001) + } + + func testClampProbabilityBelowRange() { + XCTAssertEqual(TdtDurationMapping.clampProbability(-0.5), 0.0, accuracy: 0.0001) + XCTAssertEqual(TdtDurationMapping.clampProbability(-100.0), 0.0, accuracy: 0.0001) + } + + func testClampProbabilityAboveRange() { + XCTAssertEqual(TdtDurationMapping.clampProbability(1.5), 1.0, accuracy: 0.0001) + XCTAssertEqual(TdtDurationMapping.clampProbability(100.0), 1.0, accuracy: 0.0001) + } + + func testClampProbabilityNonFinite() { + XCTAssertEqual(TdtDurationMapping.clampProbability(.nan), 0.0, accuracy: 0.0001) + XCTAssertEqual(TdtDurationMapping.clampProbability(.infinity), 0.0, accuracy: 0.0001) + XCTAssertEqual(TdtDurationMapping.clampProbability(-.infinity), 0.0, 
accuracy: 0.0001) + } + + // MARK: - TdtJointDecision Tests + + func testJointDecisionCreation() { + let decision = TdtJointDecision( + token: 42, + probability: 0.95, + durationBin: 3 + ) + + XCTAssertEqual(decision.token, 42) + XCTAssertEqual(decision.probability, 0.95, accuracy: 0.0001) + XCTAssertEqual(decision.durationBin, 3) + } + + func testJointDecisionWithNegativeValues() { + let decision = TdtJointDecision( + token: -1, + probability: 0.0, + durationBin: 0 + ) + + XCTAssertEqual(decision.token, -1) + XCTAssertEqual(decision.probability, 0.0, accuracy: 0.0001) + XCTAssertEqual(decision.durationBin, 0) + } + + // MARK: - TdtJointInputProvider Tests + + func testJointInputProviderFeatureNames() throws { + let encoderArray = try MLMultiArray(shape: [1, 640, 1], dataType: .float32) + let decoderArray = try MLMultiArray(shape: [1, 640, 1], dataType: .float32) + + let provider = ReusableJointInputProvider( + encoderStep: encoderArray, + decoderStep: decoderArray + ) + + XCTAssertEqual(provider.featureNames, ["encoder_step", "decoder_step"]) + } + + func testJointInputProviderFeatureValues() throws { + let encoderArray = try MLMultiArray(shape: [1, 640, 1], dataType: .float32) + let decoderArray = try MLMultiArray(shape: [1, 640, 1], dataType: .float32) + + let provider = ReusableJointInputProvider( + encoderStep: encoderArray, + decoderStep: decoderArray + ) + + let encoderFeature = provider.featureValue(for: "encoder_step") + let decoderFeature = provider.featureValue(for: "decoder_step") + + XCTAssertNotNil(encoderFeature) + XCTAssertNotNil(decoderFeature) + XCTAssertIdentical(encoderFeature?.multiArrayValue, encoderArray) + XCTAssertIdentical(decoderFeature?.multiArrayValue, decoderArray) + } + + func testJointInputProviderInvalidFeatureName() throws { + let encoderArray = try MLMultiArray(shape: [1, 640, 1], dataType: .float32) + let decoderArray = try MLMultiArray(shape: [1, 640, 1], dataType: .float32) + + let provider = ReusableJointInputProvider( + 
encoderStep: encoderArray, + decoderStep: decoderArray + ) + + let invalidFeature = provider.featureValue(for: "invalid_feature") + XCTAssertNil(invalidFeature, "Should return nil for invalid feature name") + } + + // MARK: - TdtModelInference Tests + + func testNormalizeDecoderProjectionAlreadyNormalized() throws { + // Input already in [1, 640, 1] format + let input = try MLMultiArray(shape: [1, 640, 1], dataType: .float32) + for i in 0..<640 { + input[[0, i, 0] as [NSNumber]] = NSNumber(value: Float(i)) + } + + let inference = TdtModelInference() + let normalized = try inference.normalizeDecoderProjection(input) + + XCTAssertEqual(normalized.shape.map { $0.intValue }, [1, 640, 1]) + + // Verify data was copied correctly + for i in 0..<640 { + let expected = Float(i) + let actual = normalized[[0, i, 0] as [NSNumber]].floatValue + XCTAssertEqual(actual, expected, accuracy: 0.0001) + } + } + + func testNormalizeDecoderProjectionTranspose() throws { + // Input in [1, 1, 640] format (needs transpose) + let input = try MLMultiArray(shape: [1, 1, 640], dataType: .float32) + for i in 0..<640 { + input[[0, 0, i] as [NSNumber]] = NSNumber(value: Float(i)) + } + + let inference = TdtModelInference() + let normalized = try inference.normalizeDecoderProjection(input) + + XCTAssertEqual(normalized.shape.map { $0.intValue }, [1, 640, 1]) + + // Verify data was copied correctly + for i in 0..<640 { + let expected = Float(i) + let actual = normalized[[0, i, 0] as [NSNumber]].floatValue + XCTAssertEqual(actual, expected, accuracy: 0.0001) + } + } + + func testNormalizeDecoderProjectionWithDestination() throws { + // Input in [1, 1, 640] format + let input = try MLMultiArray(shape: [1, 1, 640], dataType: .float32) + for i in 0..<640 { + input[[0, 0, i] as [NSNumber]] = NSNumber(value: Float(i * 2)) + } + + // Pre-allocate destination + let destination = try MLMultiArray(shape: [1, 640, 1], dataType: .float32) + + let inference = TdtModelInference() + let normalized = try 
inference.normalizeDecoderProjection(input, into: destination) + + XCTAssertIdentical(normalized, destination, "Should reuse destination array") + + // Verify data was copied correctly + for i in 0..<640 { + let expected = Float(i * 2) + let actual = normalized[[0, i, 0] as [NSNumber]].floatValue + XCTAssertEqual(actual, expected, accuracy: 0.0001) + } + } + + func testNormalizeDecoderProjectionInvalidRank() throws { + // Input with wrong rank + let input = try MLMultiArray(shape: [640], dataType: .float32) + + let inference = TdtModelInference() + XCTAssertThrowsError(try inference.normalizeDecoderProjection(input)) { error in + guard let asrError = error as? ASRError else { + XCTFail("Expected ASRError") + return + } + if case .processingFailed(let message) = asrError { + XCTAssertTrue(message.contains("Invalid decoder projection rank")) + } else { + XCTFail("Expected processingFailed error") + } + } + } + + func testNormalizeDecoderProjectionInvalidBatchSize() throws { + // Input with batch size != 1 + let input = try MLMultiArray(shape: [2, 640, 1], dataType: .float32) + + let inference = TdtModelInference() + XCTAssertThrowsError(try inference.normalizeDecoderProjection(input)) { error in + guard let asrError = error as? ASRError else { + XCTFail("Expected ASRError") + return + } + if case .processingFailed(let message) = asrError { + XCTAssertTrue(message.contains("Unsupported decoder batch dimension")) + } else { + XCTFail("Expected processingFailed error") + } + } + } + + func testNormalizeDecoderProjectionInvalidHiddenSize() throws { + // Input with wrong hidden size + let input = try MLMultiArray(shape: [1, 128, 1], dataType: .float32) + + let inference = TdtModelInference() + XCTAssertThrowsError(try inference.normalizeDecoderProjection(input)) { error in + guard let asrError = error as? 
ASRError else { + XCTFail("Expected ASRError") + return + } + if case .processingFailed(let message) = asrError { + XCTAssertTrue(message.contains("Decoder projection hidden size mismatch")) + } else { + XCTFail("Expected processingFailed error") + } + } + } + + func testNormalizeDecoderProjectionInvalidTimeAxis() throws { + // Input with time axis != 1 + let input = try MLMultiArray(shape: [1, 640, 2], dataType: .float32) + + let inference = TdtModelInference() + XCTAssertThrowsError(try inference.normalizeDecoderProjection(input)) { error in + guard let asrError = error as? ASRError else { + XCTFail("Expected ASRError") + return + } + if case .processingFailed(let message) = asrError { + XCTAssertTrue(message.contains("Decoder projection time axis must be 1")) + } else { + XCTFail("Expected processingFailed error") + } + } + } +}