diff --git a/Documentation/API.md b/Documentation/API.md index b8a46e237..95ab43c57 100644 --- a/Documentation/API.md +++ b/Documentation/API.md @@ -83,7 +83,7 @@ Use `OfflineDiarizerManager` when you need offline DER parity or want to run the **Speaker Enrollment:** `enrollSpeaker(withAudio:sourceSampleRate:named:...)` feeds known-speaker audio before streaming to label a slot. -**Lifecycle:** `reset()` clears streaming state but keeps the model loaded. `cleanup()` releases everything. +**Lifecycle:** `finalizeSession()` flushes trailing context so the last true frame becomes finalized. `reset()` clears streaming state but keeps the model loaded. `cleanup()` releases everything. --- diff --git a/Documentation/Diarization/DiarizerTimeline.md b/Documentation/Diarization/DiarizerTimeline.md index 54933b7fa..c80a639ba 100644 --- a/Documentation/Diarization/DiarizerTimeline.md +++ b/Documentation/Diarization/DiarizerTimeline.md @@ -157,4 +157,4 @@ When `timeline.addChunk(_:)` is called internally by the diarizer: 3. It iterates over all `DiarizerSpeaker` tracks, evaluating the boundaries (using `onsetThreshold` and `offsetThreshold`) to grow existing segments or spawn new ones. 4. Tentative segments are cleared and rebuilt from the trailing `tentativePredictions` array during every streaming tick. -When the stream naturally finishes, the `Diarizer` automatically invokes `timeline.finalize()`, which flushes any remaining tentative segments up to finalized status and applies the `minFramesOn` deletion rules. +When the stream naturally finishes, call `Diarizer.finalizeSession()`. The diarizer flushes trailing context first, then invokes `timeline.finalize()`, which promotes any remaining tentative segments to finalized status and applies the `minFramesOn` deletion rules. diff --git a/Documentation/Diarization/LS-EEND.md b/Documentation/Diarization/LS-EEND.md index 3e357b326..50d74c8ab 100644 --- a/Documentation/Diarization/LS-EEND.md +++ b/Documentation/Diarization/LS-EEND.md @@ -226,11 +226,14 @@ if let update = try diarizer.process() { // Convenience: add + process in one call if let update = try diarizer.process(samples: audioChunk) { ... } -// Flush remaining frames at end of stream +// Flush remaining frames at the end of a stream try diarizer.finalizeSession() let finalTimeline = diarizer.timeline ``` +Notes: +- `finalizeSession()` flushes the remaining audio by padding the end with silence. + ### Speaker Enrollment Use speaker enrollment to warm LS-EEND with a known speaker before the live stream starts. Enrollment keeps the active streaming session, resets the visible timeline back to frame 0, and preserves the speaker name inside the `DiarizerTimeline`. @@ -288,6 +291,7 @@ Real-world integration testing with 4-speaker audio reveals specific enrollment ### Lifecycle ```swift +try diarizer.finalizeSession() // Flush trailing context before reading final output diarizer.reset() // Reset streaming state for a new audio stream (keeps model loaded) diarizer.cleanup() // Release all resources including the loaded model ``` diff --git a/Documentation/Diarization/Sortformer.md b/Documentation/Diarization/Sortformer.md index 0c9812411..14733d3f1 100644 --- a/Documentation/Diarization/Sortformer.md +++ b/Documentation/Diarization/Sortformer.md @@ -368,6 +368,10 @@ public struct SortformerSegment { │ └─→ timeline.addChunk(result) │ │ └─→ Update segments per speaker │ │ │ +│ 3. finalizeSession() │ +│ └─→ pad trailing silence until last true frame is emitted │ +│ └─→ timeline.finalize() │ +│ │ └────────────────────────────────────────────────────────────────┘ ``` @@ -444,6 +448,8 @@ audioEngine.installTap { buffer in updateSpeakerDisplay(diarizer.timeline) } } + +try diarizer.finalizeSession() ``` ### Batch Processing @@ -466,6 +472,8 @@ for (index, speaker) in timeline.speakers { } ``` +`finalizeSession()` is only needed for streaming mode. It pads enough trailing silence to flush Sortformer's right-context preview frames, then finalizes the timeline so `numTentativeFrames == 0`. + ### Speaker Enrollment Use speaker enrollment to warm Sortformer with known speakers before live audio starts. Enrollment preserves the speaker cache / FIFO state, resets the visible timeline, and keeps the speaker name in the `DiarizerTimeline`. diff --git a/Sources/FluidAudio/Diarizer/DiarizerTimeline.swift b/Sources/FluidAudio/Diarizer/DiarizerTimeline.swift index a75553c82..ef11047f4 100644 --- a/Sources/FluidAudio/Diarizer/DiarizerTimeline.swift +++ b/Sources/FluidAudio/Diarizer/DiarizerTimeline.swift @@ -506,6 +506,9 @@ public struct DiarizerSegment: Sendable, Identifiable, Comparable, Equatable { /// Duration of one frame in seconds public let frameDurationSeconds: Float + /// Confidence in this speech segment (average speech probability from the diarizer) + public var confidence: Float = 0.0 + /// Start time in seconds public var startTime: Float { Float(startFrame) * frameDurationSeconds } @@ -523,7 +526,8 @@ public struct DiarizerSegment: Sendable, Identifiable, Comparable, Equatable { startFrame: Int, endFrame: Int, finalized: Bool = true, - frameDurationSeconds: Float + frameDurationSeconds: Float, + confidence: Float = 0 ) { self.id = UUID() self.speakerIndex = speakerIndex @@ -531,6 +535,7 @@ public struct DiarizerSegment: Sendable, Identifiable, Comparable, Equatable { self.endFrame = endFrame self.isFinalized = finalized self.frameDurationSeconds = frameDurationSeconds + self.confidence = confidence } public init( @@ -538,7 +543,8 @@ public struct DiarizerSegment: Sendable, Identifiable, Comparable, Equatable { startTime: Float, endTime: Float, finalized: Bool = true, - frameDurationSeconds: Float + frameDurationSeconds: Float, + confidence: Float = 0 ) { self.id = UUID() self.speakerIndex = speakerIndex @@ -546,6 +552,7 @@ public struct DiarizerSegment: Sendable, Identifiable, Comparable, Equatable { self.endFrame = Int(round(endTime / frameDurationSeconds)) self.isFinalized = finalized self.frameDurationSeconds = frameDurationSeconds + self.confidence = confidence } /// Check if this overlaps with another segment @@ -639,6 +646,13 @@ public struct DiarizerChunkResult: Sendable { /// Generalizes `SortformerTimeline` for any frame-based diarizer. Works with /// both Sortformer (fixed 4 speakers) and LS-EEND (variable speaker count). public final class DiarizerTimeline { + private struct ClosedSegmentStats { + var start: Int + var end: Int + var activitySum: Float + var activeFrameCount: Int + } + public enum KeptOnReset { case nothing case namedSpeakers @@ -650,15 +664,21 @@ public final class DiarizerTimeline { private struct StreamingState { var startFrame: Int var isSpeaking: Bool - var lastSegment: (start: Int, end: Int) + var activitySum: Float + var activeFrameCount: Int + var lastSegment: ClosedSegmentStats? init( startFrame: Int = 0, isSpeaking: Bool = false, - lastSegment: (start: Int, end: Int) = (-1, -1) + activitySum: Float = 0, + activeFrameCount: Int = 0, + lastSegment: ClosedSegmentStats? = nil ) { self.startFrame = startFrame self.isSpeaking = isSpeaking + self.activitySum = activitySum + self.activeFrameCount = activeFrameCount self.lastSegment = lastSegment } } @@ -688,6 +708,11 @@ public final class DiarizerTimeline { queue.sync { _tentativePredictions.count / speakerCapacity } } + /// Total number of frames (finalized + tentative) + public var numFrames: Int { + queue.sync { _numFinalizedFrames + _tentativePredictions.count / speakerCapacity } + } + /// Speakers in the timeline public var speakers: [Int: DiarizerSpeaker] { get { queue.sync { _speakers } } @@ -706,6 +731,7 @@ public final class DiarizerTimeline { } } + /// Whether the timeline has any segments public var hasSegments: Bool { speakers.values.contains(where: \.hasSegments) } @@ -715,6 +741,16 @@ public final class DiarizerTimeline { Float(numFinalizedFrames) * config.frameDurationSeconds } + /// Duration of tentative predictions in seconds + public var tentativeDuration: Float { + Float(numTentativeFrames) * config.frameDurationSeconds + } + + /// Duration of all predictions (finalized + tentative) in seconds + public var duration: Float { + Float(numFrames) * config.frameDurationSeconds + } + /// Maximum number of speakers public var speakerCapacity: Int { config.numSpeakers @@ -1103,42 +1139,64 @@ public final class DiarizerTimeline { var start = state.startFrame var speaking = state.isSpeaking + var activitySum = state.activitySum + var activeFrameCount = state.activeFrameCount var lastSegment = state.lastSegment var wasLastSegmentFinal = isFinalized for i in 0..= offset { + if activity >= offset { + activitySum += activity + activeFrameCount += 1 continue } speaking = false let end = frameOffset + i + padOffset - guard end - start > minFramesOn else { continue } + guard end - start > minFramesOn else { + activitySum = 0 + activeFrameCount = 0 + continue + } wasLastSegmentFinal = isFinalized && (end < tentativeStartFrame) + let confidence = activeFrameCount > 0 ? (activitySum / Float(activeFrameCount)) : 0 let newSegment = DiarizerSegment( speakerIndex: speakerIndex, startFrame: start, endFrame: end, finalized: wasLastSegmentFinal, - frameDurationSeconds: frameDuration + frameDurationSeconds: frameDuration, + confidence: confidence ) provideSpeaker(forSlot: speakerIndex).append(newSegment) - lastSegment = (start, end) + lastSegment = ClosedSegmentStats( + start: start, + end: end, + activitySum: activitySum, + activeFrameCount: activeFrameCount + ) + activitySum = 0 + activeFrameCount = 0 - } else if predictions[index] > onset { + } else if activity > onset { start = max(0, frameOffset + i - padOnset) speaking = true + activitySum = activity + activeFrameCount = 1 - if start - lastSegment.end <= minFramesOff { + if let lastSegment, start - lastSegment.end <= minFramesOff { start = lastSegment.start + activitySum += lastSegment.activitySum + activeFrameCount += lastSegment.activeFrameCount _speakers[speakerIndex]?.popLast(fromFinalized: wasLastSegmentFinal) } } @@ -1147,18 +1205,22 @@ public final class DiarizerTimeline { if isFinalized { states[speakerIndex].startFrame = start states[speakerIndex].isSpeaking = speaking + states[speakerIndex].activitySum = activitySum + states[speakerIndex].activeFrameCount = activeFrameCount states[speakerIndex].lastSegment = lastSegment } if addTrailingTentative { let end = frameOffset + numFrames + padOffset if speaking && (end > start) { + let confidence = activeFrameCount > 0 ? (activitySum / Float(activeFrameCount)) : 0 let newSegment = DiarizerSegment( speakerIndex: speakerIndex, startFrame: start, endFrame: end, finalized: false, - frameDurationSeconds: frameDuration + frameDurationSeconds: frameDuration, + confidence: confidence ) provideSpeaker(forSlot: speakerIndex).appendTentative(newSegment) } diff --git a/Sources/FluidAudio/Diarizer/LS-EEND/LSEENDDiarizerAPI.swift b/Sources/FluidAudio/Diarizer/LS-EEND/LSEENDDiarizerAPI.swift index 0a9aaec94..b74affa83 100644 --- a/Sources/FluidAudio/Diarizer/LS-EEND/LSEENDDiarizerAPI.swift +++ b/Sources/FluidAudio/Diarizer/LS-EEND/LSEENDDiarizerAPI.swift @@ -14,44 +14,32 @@ public final class LSEENDDiarizer: Diarizer { /// Accumulated results public var timeline: DiarizerTimeline { - lock.lock() - defer { lock.unlock() } - return _timeline + lock.withLock { return _timeline } } /// Whether the processor is ready for processing public var isAvailable: Bool { - lock.lock() - defer { lock.unlock() } - return _engine != nil + lock.withLock { return _engine != nil } } /// Number of confirmed frames processed so far public var numFramesProcessed: Int { - lock.lock() - defer { lock.unlock() } - return _numFramesProcessed + lock.withLock { return _numFramesProcessed } } /// Model's target sample rate in Hz (e.g., 8000) public var targetSampleRate: Int? { - lock.lock() - defer { lock.unlock() } - return _engine?.targetSampleRate + lock.withLock { return _engine?.targetSampleRate } } /// Output frame rate in Hz (e.g., 10.0) public var modelFrameHz: Double? { - lock.lock() - defer { lock.unlock() } - return _engine?.modelFrameHz + lock.withLock { return _engine?.modelFrameHz } } /// Number of real speaker tracks (excluding boundary tracks) public var numSpeakers: Int? { - lock.lock() - defer { lock.unlock() } - return _engine?.metadata.realOutputDim + lock.withLock { return _engine?.metadata.realOutputDim } } // MARK: - Additional Properties @@ -61,30 +49,22 @@ public final class LSEENDDiarizer: Diarizer { /// Post-processing configuration public var timelineConfig: DiarizerTimelineConfig { - lock.lock() - defer { lock.unlock() } - return _timeline.config + lock.withLock { return _timeline.config } } /// Streaming latency in seconds public var streamingLatencySeconds: Double? { - lock.lock() - defer { lock.unlock() } - return _engine?.streamingLatencySeconds + lock.withLock { return _engine?.streamingLatencySeconds } } /// Total speaker slots in model output (including boundary tracks) public var decodeMaxSpeakers: Int? { - lock.lock() - defer { lock.unlock() } - return _engine?.decodeMaxSpeakers + lock.withLock { return _engine?.decodeMaxSpeakers } } /// Whether a streaming session is currently active. var hasActiveSession: Bool { - lock.lock() - defer { lock.unlock() } - return _session != nil + lock.withLock { return _session != nil } } // MARK: - Private State @@ -176,39 +156,37 @@ public final class LSEENDDiarizer: Diarizer { let engine = try LSEENDInferenceHelper(descriptor: descriptor, computeUnits: computeUnits) let melSpectrogram = Self.createMelSpectrogram(featureConfig: engine.featureConfig) - lock.lock() - defer { lock.unlock() } - - updateTimelineConfig(engine: engine) - _engine = engine - _melSpectrogram = melSpectrogram - _timeline = DiarizerTimeline(config: _timelineConfig) - _session = nil - resetBuffersLocked() + lock.withLock { + updateTimelineConfig(engine: engine) + _engine = engine + _melSpectrogram = melSpectrogram + _timeline = DiarizerTimeline(config: _timelineConfig) + _session = nil + resetBuffersLocked() - logger.info( - "Initialized LS-EEND \(descriptor.variant.rawValue): " - + "\(engine.metadata.realOutputDim) speakers, " - + "\(String(format: "%.1f", engine.modelFrameHz)) Hz, " - + "\(String(format: "%.2f", engine.streamingLatencySeconds))s latency" - ) + logger.info( + "Initialized LS-EEND \(descriptor.variant.rawValue): " + + "\(engine.metadata.realOutputDim) speakers, " + + "\(String(format: "%.1f", engine.modelFrameHz)) Hz, " + + "\(String(format: "%.2f", engine.streamingLatencySeconds))s latency" + ) + } } /// Initialize with a pre-loaded engine. public func initialize(engine: LSEENDInferenceHelper) { let melSpectrogram = Self.createMelSpectrogram(featureConfig: engine.featureConfig) - lock.lock() - defer { lock.unlock() } - - updateTimelineConfig(engine: engine) - _engine = engine - _melSpectrogram = melSpectrogram - _timeline = DiarizerTimeline(config: _timelineConfig) - _session = nil - resetBuffersLocked() + lock.withLock { + updateTimelineConfig(engine: engine) + _engine = engine + _melSpectrogram = melSpectrogram + _timeline = DiarizerTimeline(config: _timelineConfig) + _session = nil + resetBuffersLocked() - logger.info("Initialized LS-EEND with pre-loaded engine") + logger.info("Initialized LS-EEND with pre-loaded engine") + } } // MARK: - Speaker Priming @@ -268,103 +246,101 @@ public final class LSEENDDiarizer: Diarizer { named name: String?, overwritingAssignedSpeakerName overwriteAssignedSpeakerName: Bool ) throws -> DiarizerSpeaker? { - let description: String = name.map { "named '\($0)'" } ?? "(no name)" - - lock.lock() - defer { lock.unlock() } - - guard let engine = _engine else { - throw LSEENDError.modelPredictionFailed("LS-EEND processor not initialized. Call initialize() first.") - } + try lock.withLock { + let description: String = name.map { "named '\($0)'" } ?? "(no name)" + guard let engine = _engine else { + throw LSEENDError.modelPredictionFailed("LS-EEND processor not initialized. Call initialize() first.") + } - let normalized = try normalizeSamplesLocked(samples, sourceSampleRate: sourceSampleRate) ?? samples - guard !normalized.isEmpty else { - logger.warning("Failed to enroll speaker \(description) because no speech detected") - return nil - } + let normalized = try normalizeSamplesLocked(samples, sourceSampleRate: sourceSampleRate) ?? samples + guard !normalized.isEmpty else { + logger.warning("Failed to enroll speaker \(description) because no speech detected") + return nil + } - if _timeline.hasSegments { - logger.warning("Trying to enroll a speaker while timeline has segments; timeline will be reset") - } + if _timeline.hasSegments { + logger.warning("Trying to enroll a speaker while timeline has segments; timeline will be reset") + } - _timeline.reset(keepingSpeakers: true) - var occupiedIndices = Set(_timeline.speakers.keys) - _numFramesProcessed = 0 - _visibleStartFrameOffset = 0 - pendingAudio.removeAll(keepingCapacity: true) + _timeline.reset(keepingSpeakers: true) + var occupiedIndices = Set(_timeline.speakers.keys) + _numFramesProcessed = 0 + _visibleStartFrameOffset = 0 + pendingAudio.removeAll(keepingCapacity: true) - if _session == nil { - _session = try engine.createSession( - inputSampleRate: engine.targetSampleRate, melSpectrogram: _melSpectrogram!) - } - guard let session = _session else { - return nil - } + if _session == nil { + _session = try engine.createSession( + inputSampleRate: engine.targetSampleRate, melSpectrogram: _melSpectrogram!) + } + guard let session = _session else { + return nil + } - let update = try session.pushAudio(normalized) - let didProcess = update.map { !$0.probabilities.isEmpty || !$0.previewProbabilities.isEmpty } ?? false + let update = try session.pushAudio(normalized) + let didProcess = update.map { !$0.probabilities.isEmpty || !$0.previewProbabilities.isEmpty } ?? false - guard didProcess else { - let minimumSeconds = engine.streamingLatencySeconds - logger.warning( - "Failed to enroll speaker \(description): not enough audio was provided. " - + "Please provide at least \(String(format: "%.2f", minimumSeconds)) seconds of speech." - ) - return nil - } + guard didProcess else { + let minimumSeconds = engine.streamingLatencySeconds + logger.warning( + "Failed to enroll speaker \(description): not enough audio was provided. " + + "Please provide at least \(String(format: "%.2f", minimumSeconds)) seconds of speech." + ) + return nil + } - if let update { - let numSpeakers = engine.metadata.realOutputDim - let result = DiarizerChunkResult( - startFrame: max(0, update.startFrame - _visibleStartFrameOffset), - finalizedPredictions: flattenRowMajor(update.probabilities, numSpeakers: numSpeakers), - finalizedFrameCount: update.probabilities.rows, - tentativePredictions: flattenRowMajor(update.previewProbabilities, numSpeakers: numSpeakers), - tentativeFrameCount: update.previewProbabilities.rows - ) - _numFramesProcessed += result.finalizedFrameCount - _ = try _timeline.addChunk(result) - } + if let update { + let numSpeakers = engine.metadata.realOutputDim + let result = DiarizerChunkResult( + startFrame: max(0, update.startFrame - _visibleStartFrameOffset), + finalizedPredictions: flattenRowMajor(update.probabilities, numSpeakers: numSpeakers), + finalizedFrameCount: update.probabilities.rows, + tentativePredictions: flattenRowMajor(update.previewProbabilities, numSpeakers: numSpeakers), + tentativeFrameCount: update.previewProbabilities.rows + ) + _numFramesProcessed += result.finalizedFrameCount + _ = try _timeline.addChunk(result) + } - let speaker = _timeline.speakers.values.max { $0.numSpeechFrames < $1.numSpeechFrames } - let enrolledSpeaker: DiarizerSpeaker? - if let speaker, speaker.hasSegments { - if let oldName = speaker.name { - guard overwriteAssignedSpeakerName else { + let speaker = _timeline.speakers.values.max { $0.numSpeechFrames < $1.numSpeechFrames } + let enrolledSpeaker: DiarizerSpeaker? + if let speaker, speaker.hasSegments { + if let oldName = speaker.name { + guard overwriteAssignedSpeakerName else { + logger.warning( + "Failed to enroll speaker \(description): diarizer matched existing speaker '\(oldName)' " + + "at index \(speaker.index) and overwritingAssignedSpeakerName=false" + ) + _visibleStartFrameOffset = session.snapshot().probabilities.rows + _numFramesProcessed = 0 + _timeline.reset(keepingSpeakersWhere: { occupiedIndices.contains($0.index) }) + pendingAudio.removeAll(keepingCapacity: true) + return nil + } logger.warning( - "Failed to enroll speaker \(description): diarizer matched existing speaker '\(oldName)' " - + "at index \(speaker.index) and overwritingAssignedSpeakerName=false" + "Newly-enrolled speaker \(description) will overwrite the old one named \(oldName) at index \(speaker.index)" ) - _visibleStartFrameOffset = session.snapshot().probabilities.rows - _numFramesProcessed = 0 - _timeline.reset(keepingSpeakersWhere: { occupiedIndices.contains($0.index) }) - pendingAudio.removeAll(keepingCapacity: true) - return nil } - logger.warning( - "Newly-enrolled speaker \(description) will overwrite the old one named \(oldName) at index \(speaker.index)" - ) + speaker.name = name + occupiedIndices.insert(speaker.index) + enrolledSpeaker = speaker + } else { + logger.warning("Failed to enroll speaker \(description) because no speech detected") + enrolledSpeaker = nil } - speaker.name = name - occupiedIndices.insert(speaker.index) - enrolledSpeaker = speaker - } else { - logger.warning("Failed to enroll speaker \(description) because no speech detected") - enrolledSpeaker = nil - } - _visibleStartFrameOffset = session.snapshot().probabilities.rows - _numFramesProcessed = 0 - _timeline.reset(keepingSpeakersWhere: { occupiedIndices.contains($0.index) }) - pendingAudio.removeAll(keepingCapacity: true) + _visibleStartFrameOffset = session.snapshot().probabilities.rows + _numFramesProcessed = 0 + _timeline.reset(keepingSpeakersWhere: { occupiedIndices.contains($0.index) }) + pendingAudio.removeAll(keepingCapacity: true) - logger.info( - "Enrolled speaker \(description) with \(normalized.count) samples " - + "(\(String(format: "%.1f", Float(normalized.count) / Float(engine.targetSampleRate)))s), " - + "visible offset=\(_visibleStartFrameOffset)" - ) + logger.info( + "Enrolled speaker \(description) with \(normalized.count) samples " + + "(\(String(format: "%.1f", Float(normalized.count) / Float(engine.targetSampleRate)))s), " + + "visible offset=\(_visibleStartFrameOffset)" + ) - return enrolledSpeaker + return enrolledSpeaker + } } // MARK: - Streaming (Diarizer Protocol) @@ -387,13 +363,12 @@ public final class LSEENDDiarizer: Diarizer { _ samples: C, sourceSampleRate: Double? = nil ) throws where C.Element == Float { - lock.lock() - defer { lock.unlock() } - - if let normalized = try normalizeSamplesLocked(samples, sourceSampleRate: sourceSampleRate) { - pendingAudio.append(contentsOf: normalized) - } else { - pendingAudio.append(contentsOf: samples) + try lock.withLock { + if let normalized = try normalizeSamplesLocked(samples, sourceSampleRate: sourceSampleRate) { + pendingAudio.append(contentsOf: normalized) + } else { + pendingAudio.append(contentsOf: samples) + } } } @@ -401,15 +376,10 @@ public final class LSEENDDiarizer: Diarizer { /// /// - Returns: New chunk result if inference produced frames, nil otherwise public func process() throws -> DiarizerTimelineUpdate? { - lock.lock() - defer { lock.unlock() } - return try processLocked() + try lock.withLock { return try processLocked() } } - /// Process a chunk of audio in one call. - /// - /// Convenience method that combines `addAudio()` and `process()`. - /// + /// Add and process a chunk of audio in one call. /// - Parameters: /// - samples: Audio samples to process. /// - sourceSampleRate: Sample rate of `samples`, or `nil` if already at the model rate. @@ -418,16 +388,15 @@ public final class LSEENDDiarizer: Diarizer { samples: C, sourceSampleRate: Double? = nil ) throws -> DiarizerTimelineUpdate? where C.Element == Float { - lock.lock() - defer { lock.unlock() } + try lock.withLock { + if let normalized = try normalizeSamplesLocked(samples, sourceSampleRate: sourceSampleRate) { + pendingAudio.append(contentsOf: normalized) + } else { + pendingAudio.append(contentsOf: samples) + } - if let normalized = try normalizeSamplesLocked(samples, sourceSampleRate: sourceSampleRate) { - pendingAudio.append(contentsOf: normalized) - } else { - pendingAudio.append(contentsOf: samples) + return try processLocked() } - - return try processLocked() } /// Internal process — caller must hold lock. @@ -641,26 +610,24 @@ public final class LSEENDDiarizer: Diarizer { /// /// Preserves the loaded model. Call `initialize()` again to change models. public func reset() { - lock.lock() - defer { lock.unlock() } - - _session = nil - _timeline.reset() - resetBuffersLocked() - logger.debug("LS-EEND state reset") + lock.withLock { + _session = nil + _timeline.reset() + resetBuffersLocked() + logger.debug("LS-EEND state reset") + } } /// Clean up all resources including the loaded model. public func cleanup() { - lock.lock() - defer { lock.unlock() } - - _engine = nil - _session = nil - _melSpectrogram = nil - _timeline.reset() - resetBuffersLocked() - logger.info("LS-EEND resources cleaned up") + lock.withLock { + _engine = nil + _session = nil + _melSpectrogram = nil + _timeline.reset() + resetBuffersLocked() + logger.info("LS-EEND resources cleaned up") + } } // MARK: - LS-EEND Specific @@ -678,6 +645,8 @@ public final class LSEENDDiarizer: Diarizer { defer { lock.unlock() } guard let engine = _engine, let session = _session else { return nil } + let numSpeakers = engine.metadata.realOutputDim + var lastResult: DiarizerChunkResult? // Flush pending audio first — clear unconditionally so failed audio isn't retained. // Using defer + direct pass avoids a CoW copy. @@ -685,7 +654,6 @@ public final class LSEENDDiarizer: Diarizer { defer { pendingAudio.removeAll(keepingCapacity: true) } let pushedUpdate = try session.pushAudio(pendingAudio) if let update = pushedUpdate { - let numSpeakers = engine.metadata.realOutputDim let flushedResult = DiarizerChunkResult( startFrame: _numFramesProcessed, finalizedPredictions: flattenRowMajor(update.probabilities, numSpeakers: numSpeakers), @@ -695,29 +663,26 @@ public final class LSEENDDiarizer: Diarizer { ) _numFramesProcessed += flushedResult.finalizedFrameCount try _timeline.addChunk(flushedResult) + lastResult = flushedResult } } - guard let finalUpdate = try session.finalize() else { - _session = nil - _timeline.finalize() - return nil + if let finalUpdate = try session.finalize() { + let finalResult = DiarizerChunkResult( + startFrame: _numFramesProcessed, + finalizedPredictions: flattenRowMajor(finalUpdate.probabilities, numSpeakers: numSpeakers), + finalizedFrameCount: finalUpdate.probabilities.rows, + tentativePredictions: [], + tentativeFrameCount: 0 + ) + _numFramesProcessed += finalResult.finalizedFrameCount + try _timeline.addChunk(finalResult) + lastResult = finalResult } - - let numSpeakers = engine.metadata.realOutputDim - let result = DiarizerChunkResult( - startFrame: _numFramesProcessed, - finalizedPredictions: flattenRowMajor(finalUpdate.probabilities, numSpeakers: numSpeakers), - finalizedFrameCount: finalUpdate.probabilities.rows, - tentativePredictions: [], - tentativeFrameCount: 0 - ) - _numFramesProcessed += result.finalizedFrameCount - try _timeline.addChunk(result) _timeline.finalize() _session = nil - return result + return lastResult } // MARK: - Private diff --git a/Sources/FluidAudio/Diarizer/LS-EEND/LSEENDModelInference.swift b/Sources/FluidAudio/Diarizer/LS-EEND/LSEENDModelInference.swift index 69afd6a36..f978b9006 100644 --- a/Sources/FluidAudio/Diarizer/LS-EEND/LSEENDModelInference.swift +++ b/Sources/FluidAudio/Diarizer/LS-EEND/LSEENDModelInference.swift @@ -208,17 +208,11 @@ public final class LSEENDInferenceHelper { /// - Returns: Complete inference result with logits and probabilities. public func infer(samples: [Float], sampleRate: Int) throws -> LSEENDInferenceResult { let normalizedAudio = try resampleIfNeeded(samples: samples, sampleRate: sampleRate) - let features = try offlineFeatureExtractor.extractFeatures(audio: normalizedAudio) let session = try createSession(inputSampleRate: targetSampleRate) - session.totalInputSamples = normalizedAudio.count - let committed = try session.ingestFeatures(features) - let pending = session.totalFeatureFrames - session.emittedFrames - let tail = - try pending > 0 - ? session.flushTail(from: session.state, pendingFrames: pending) : .empty(columns: decodeMaxSpeakers) - let fullLogits = committed.appendingRows(tail) - session.fullLogitChunks = fullLogits.isEmpty ? [] : [fullLogits] - session.emittedFrames = fullLogits.rows + if !normalizedAudio.isEmpty { + _ = try session.pushAudio(normalizedAudio) + } + _ = try session.finalize() return session.snapshot() } @@ -518,13 +512,45 @@ public final class LSEENDStreamingSession { guard !finalized else { return nil } - let features = try featureExtractor.finalize() - let committed = try ingestFeatures(features) + + var committedFullLogits = LSEENDMatrix.empty(columns: engine.decodeMaxSpeakers) + let targetEndFrame = Int( + round(Double(totalInputSamples) / Double(max(inputSampleRate, 1)) * engine.modelFrameHz)) + let exactPaddingSamples = try exactFinalizationPaddingSamples(targetEndFrame: targetEndFrame) + if exactPaddingSamples > 0 { + let features = try featureExtractor.pushAudio([Float](repeating: 0, count: exactPaddingSamples)) + let committed = try ingestFeatures(features) + if committed.rows > 0 { + committedFullLogits = committedFullLogits.appendingRows(committed) + } + } + + let finalFeatures = try featureExtractor.finalize() + let finalCommitted = try ingestFeatures(finalFeatures) + if finalCommitted.rows > 0 { + committedFullLogits = committedFullLogits.appendingRows(finalCommitted) + } + let pending = totalFeatureFrames - emittedFrames let tail = try pending > 0 ? flushTail(from: state, pendingFrames: pending) : .empty(columns: engine.decodeMaxSpeakers) + emittedFrames += tail.rows finalized = true - return try buildUpdate(committedFullLogits: committed.appendingRows(tail), includePreview: false) + return try buildUpdate(committedFullLogits: committedFullLogits.appendingRows(tail), includePreview: false) + } + + private func exactFinalizationPaddingSamples(targetEndFrame: Int) throws -> Int { + guard targetEndFrame > 0 else { + return 0 + } + let stableBlockSize = engine.metadata.resolvedHopLength * engine.metadata.resolvedSubsampling + let (requiredTotalSamples, overflow) = targetEndFrame.multipliedReportingOverflow(by: stableBlockSize) + guard !overflow else { + throw LSEENDError.unsupportedAudio( + "Finalization padding overflowed for \(targetEndFrame) frames at block size \(stableBlockSize)." + ) + } + return max(0, requiredTotalSamples - totalInputSamples) } /// Assembles the full inference result from all committed frames emitted so far. diff --git a/Sources/FluidAudio/Diarizer/Sortformer/SortformerDiarizerPipeline.swift b/Sources/FluidAudio/Diarizer/Sortformer/SortformerDiarizerPipeline.swift index 2c499304b..6187e33db 100644 --- a/Sources/FluidAudio/Diarizer/Sortformer/SortformerDiarizerPipeline.swift +++ b/Sources/FluidAudio/Diarizer/Sortformer/SortformerDiarizerPipeline.swift @@ -44,6 +44,7 @@ public final class SortformerDiarizer: Diarizer { return _numFramesProcessed } private var _numFramesProcessed: Int = 0 + private var _realSamplesReceived: Int = 0 /// Configuration public let config: SortformerConfig @@ -151,6 +152,7 @@ public final class SortformerDiarizer: Diarizer { startFeat = 0 diarizerChunkIndex = 0 _numFramesProcessed = 0 + _realSamplesReceived = 0 _timeline.reset(keepingSpeakers: keepingSpeakers) featureBuffer.reserveCapacity((config.chunkMelFrames + config.coreFrames) * config.melFeatures) @@ -253,6 +255,7 @@ public final class SortformerDiarizer: Diarizer { diarizerChunkIndex = 0 audioBuffer.removeAll(keepingCapacity: true) featureBuffer.removeAll(keepingCapacity: true) + _realSamplesReceived = 0 audioBuffer.append(contentsOf: normalized) preprocessAudioToFeaturesLocked() @@ -285,6 +288,7 @@ public final class SortformerDiarizer: Diarizer { lastAudioSample = 0 audioBuffer.removeAll(keepingCapacity: true) featureBuffer.removeAll(keepingCapacity: true) + _realSamplesReceived = 0 return nil } logger.warning( @@ -306,6 +310,7 @@ public final class SortformerDiarizer: Diarizer { lastAudioSample = 0 audioBuffer.removeAll(keepingCapacity: true) featureBuffer.removeAll(keepingCapacity: true) + _realSamplesReceived = 0 logger.info( "Enrolled speaker \(description) with \(normalized.count) samples " @@ -327,6 +332,7 @@ public final class SortformerDiarizer: Diarizer { public func addAudio(_ samples: [Float]) { lock.withLock { audioBuffer.append(contentsOf: samples) + _realSamplesReceived += samples.count preprocessAudioToFeaturesLocked() } } @@ -343,6 +349,7 @@ public final class SortformerDiarizer: Diarizer { let normalized = try normalizeSamples(samples, sourceSampleRate: sourceSampleRate) lock.withLock { audioBuffer.append(contentsOf: normalized) + _realSamplesReceived += normalized.count preprocessAudioToFeaturesLocked() } } @@ -360,8 +367,10 @@ public final class SortformerDiarizer: Diarizer { if let sourceSampleRate, sourceSampleRate != Double(config.sampleRate) { let normalized = try normalizeSamples(Array(samples), sourceSampleRate: sourceSampleRate) audioBuffer.append(contentsOf: normalized) + _realSamplesReceived += normalized.count } else { audioBuffer.append(contentsOf: samples) + _realSamplesReceived += samples.count } preprocessAudioToFeaturesLocked() } @@ -378,14 +387,7 @@ public final class SortformerDiarizer: Diarizer { } } - /// Process a chunk of audio in one call. - /// - /// Convenience method that combines `addAudio()` and `process()`. - /// - /// Process a chunk of audio in one call. - /// - /// Convenience method that combines `addAudio()` and `process()`. - /// + /// Add and process a chunk of audio in one call. /// - Parameters: /// - samples: Audio samples (16kHz mono) /// - sourceSampleRate: Source audio sample rate @@ -399,8 +401,10 @@ public final class SortformerDiarizer: Diarizer { if let sourceSampleRate, sourceSampleRate != Double(config.sampleRate) { let normalized = try normalizeSamples(Array(samples), sourceSampleRate: sourceSampleRate) audioBuffer.append(contentsOf: normalized) + _realSamplesReceived += normalized.count } else { audioBuffer.append(contentsOf: samples) + _realSamplesReceived += samples.count } preprocessAudioToFeaturesLocked() return try processLocked() @@ -409,6 +413,16 @@ public final class SortformerDiarizer: Diarizer { /// Internal process - caller must hold lock private func processLocked(updateTimeline: Bool = true) throws -> DiarizerTimelineUpdate? { + guard let chunk = try makeStreamingChunkLocked() else { + return nil + } + + _numFramesProcessed += chunk.finalizedFrameCount + guard updateTimeline else { return nil } + return try _timeline.addChunk(chunk) + } + + private func makeStreamingChunkLocked() throws -> DiarizerChunkResult? { guard let models = _models else { throw SortformerError.notInitialized } @@ -457,22 +471,94 @@ public final class SortformerDiarizer: Diarizer { } // Return new results if any - if newPredictions.count > 0 && updateTimeline { - let chunk = DiarizerChunkResult( + if newPredictions.count > 0 { + return DiarizerChunkResult( startFrame: _numFramesProcessed, finalizedPredictions: newPredictions, finalizedFrameCount: newFrameCount, tentativePredictions: newTentativePredictions, tentativeFrameCount: newTentativeFrameCount ) - - _numFramesProcessed += newFrameCount - return try _timeline.addChunk(chunk) } return nil } + /// Finalize the current streaming session. + /// + /// Pads the tail with silence until the last true frame has been emitted as + /// finalized output, then finalizes the timeline. + /// + /// - Returns: The last finalized chunk emitted during finalization, if any. + @discardableResult + public func finalizeSession() throws -> DiarizerChunkResult? { + return try lock.withLock { + guard _models != nil else { + throw SortformerError.notInitialized + } + + var lastResult: DiarizerChunkResult? + var tentativeToFlush = _timeline.numTentativeFrames + if let chunk = try makeStreamingChunkLocked() { + tentativeToFlush = chunk.tentativeFrameCount + let finalizedChunk = DiarizerChunkResult( + startFrame: chunk.startFrame, + finalizedPredictions: chunk.finalizedPredictions, + finalizedFrameCount: chunk.finalizedFrameCount, + tentativePredictions: [], + tentativeFrameCount: 0 + ) + _numFramesProcessed = finalizedChunk.startFrame + finalizedChunk.finalizedFrameCount + try _timeline.addChunk(finalizedChunk) + lastResult = finalizedChunk + } + + let targetEndFrame = _numFramesProcessed + min(tentativeToFlush, config.chunkLen) + let exactPaddingSamples = try exactFinalizationPaddingSamples(targetEndFrame: targetEndFrame) + if exactPaddingSamples > 0 { + audioBuffer.append(contentsOf: [Float](repeating: 0, count: exactPaddingSamples)) + } + + while _numFramesProcessed < targetEndFrame { + let remainingFrames = targetEndFrame - _numFramesProcessed + let targetFeatureFrames = + startFeat + + min(remainingFrames * config.subsamplingFactor, config.coreFrames) + + config.chunkRightContext * config.subsamplingFactor + preprocessAudioToFeatureTargetLocked(targetFeatureFrames: targetFeatureFrames) + guard let chunk = try makeStreamingChunkLocked(), chunk.finalizedFrameCount > 0 else { + logger.warning( + "Sortformer finalize could not emit enough confirmed frames " + + "(\(_numFramesProcessed) / \(targetEndFrame) frames)" + ) + break + } + + let finalizedFrameCount = min(remainingFrames, chunk.finalizedFrameCount) + guard finalizedFrameCount > 0 else { + break + } + + let finalizedPredictions = Array( + chunk.finalizedPredictions.prefix(finalizedFrameCount * config.numSpeakers) + ) + let finalizedResult = DiarizerChunkResult( + startFrame: _numFramesProcessed, + finalizedPredictions: finalizedPredictions, + finalizedFrameCount: finalizedFrameCount, + tentativePredictions: [], + tentativeFrameCount: 0 + ) + _numFramesProcessed += finalizedFrameCount + try _timeline.addChunk(finalizedResult) + lastResult = finalizedResult + } + + _timeline.finalize() + return lastResult + } + } + // MARK: - Complete File Processing /// Progress callback type: (processedSamples, totalSamples, chunksProcessed) @@ -660,38 +746,17 @@ public final class SortformerDiarizer: Diarizer { /// Preprocess audio into mel features - caller must hold lock private func preprocessAudioToFeaturesLocked() { + let targetFeatureFrames = startFeat + config.coreFrames + config.chunkRightContext * config.subsamplingFactor + preprocessAudioToFeatureTargetLocked(targetFeatureFrames: targetFeatureFrames) + } + + private func preprocessAudioToFeatureTargetLocked(targetFeatureFrames: Int) { guard !audioBuffer.isEmpty else { return } if audioBuffer.count < config.melWindow { return } - // Demand-Driven Optimization: - // Calculate exactly how many features we need for the next chunk - // needed = (startFeat + core + RC) - currentFeatureCount - let featLength = featureBuffer.count / config.melFeatures - let coreFrames = config.chunkLen * config.subsamplingFactor - let rightContextFrames = config.chunkRightContext * config.subsamplingFactor - - // Calculate absolute target position in feature stream - // For Chunk 0: startFeat=0. Target=104. - // For Chunk 1: startFeat=8. Target=112. - let targetEnd = startFeat + coreFrames + rightContextFrames - - let framesNeeded = targetEnd - featLength - - // If we already have enough frames, we don't strictly need to process more right now. - // However, to keep the pipeline moving smoothly, we can process if we have a full chunk buffered. - // But to strictly prioritize efficiency/latency balance as requested: - if framesNeeded <= 0 { - // We have enough features for the next chunk! - // Check if we have A LOT of audio buffered (buffer pressure)? - // If we have > 1 second of audio, maybe process it batch-wise? - // For now, lazy approach: don't process. - return - } - - // Calculate audio samples needed to produce 'framesNeeded' - // If we are appending to existing stream (featureBuffer not empty), we need stride * N. - // If featureBuffer is empty (start of stream), we need window + (N-1)*stride. + let framesNeeded = targetFeatureFrames - featLength + guard framesNeeded > 0 else { return } let samplesNeeded: Int if featureBuffer.isEmpty { @@ -700,12 +765,7 @@ public final class SortformerDiarizer: Diarizer { samplesNeeded = framesNeeded * config.melStride } - // Wait until we have enough audio to satisfy the demand - if audioBuffer.count < samplesNeeded { return } - - // We have enough audio! Process exactly what's needed (or slightly more if convenient?) - // Let's process everything we have, since we paid the initialization cost check. - // This prevents creating a backlog of unprocessed audio. + guard audioBuffer.count >= samplesNeeded else { return } let (mel, melLength, _) = melSpectrogram.computeFlatTransposed( audio: audioBuffer, @@ -741,6 +801,89 @@ public final class SortformerDiarizer: Diarizer { .resample(samples, from: sourceSampleRate) } + private func exactFinalizationPaddingSamples(targetEndFrame: Int) throws -> Int { + let remainingFrames = max(0, targetEndFrame - _numFramesProcessed) + let (remainingFeatureFrames, remainingOverflow) = remainingFrames.multipliedReportingOverflow( + by: config.subsamplingFactor + ) + guard !remainingOverflow else { + throw SortformerError.invalidState( + "Finalization remaining-frame expansion overflowed for \(remainingFrames) frames." + ) + } + + let (rightContextFeatureFrames, rightContextOverflow) = config.chunkRightContext.multipliedReportingOverflow( + by: config.subsamplingFactor + ) + guard !rightContextOverflow else { + throw SortformerError.invalidState( + "Finalization right-context expansion overflowed for \(config.chunkRightContext) frames." + ) + } + + let (targetWithoutContext, startOverflow) = startFeat.addingReportingOverflow(remainingFeatureFrames) + guard !startOverflow else { + throw SortformerError.invalidState( + "Finalization target feature frame calculation overflowed at startFeat=\(startFeat)." + ) + } + + let (targetFeatureFrames, contextOverflow) = targetWithoutContext.addingReportingOverflow( + rightContextFeatureFrames) + guard !contextOverflow else { + throw SortformerError.invalidState( + "Finalization target feature frame calculation overflowed after adding right context." + ) + } + + let currentFeatureFrames = featureBuffer.count / config.melFeatures + let additionalFeatureFramesNeeded = max(0, targetFeatureFrames - currentFeatureFrames) + guard additionalFeatureFramesNeeded > 0 else { + return 0 + } + + let framesAvailableWithoutPadding = producedMelFramesAvailable() + guard additionalFeatureFramesNeeded > framesAvailableWithoutPadding else { + return 0 + } + + let requiredBufferedSamples: Int + if featureBuffer.isEmpty { + let (additionalSamples, overflow) = max(0, additionalFeatureFramesNeeded - 1).multipliedReportingOverflow( + by: config.melStride + ) + guard !overflow else { + throw SortformerError.invalidState( + "Finalization sample requirement overflowed for \(additionalFeatureFramesNeeded) feature frames." + ) + } + let (samples, windowOverflow) = additionalSamples.addingReportingOverflow(config.melWindow) + guard !windowOverflow else { + throw SortformerError.invalidState( + "Finalization sample requirement overflowed after adding melWindow." + ) + } + requiredBufferedSamples = samples + } else { + let (samples, overflow) = additionalFeatureFramesNeeded.multipliedReportingOverflow(by: config.melStride) + guard !overflow else { + throw SortformerError.invalidState( + "Finalization sample requirement overflowed for \(additionalFeatureFramesNeeded) buffered frames." + ) + } + requiredBufferedSamples = max(config.melWindow, samples) + } + + return max(0, requiredBufferedSamples - audioBuffer.count) + } + + private func producedMelFramesAvailable() -> Int { + guard audioBuffer.count >= config.melWindow else { + return 0 + } + return audioBuffer.count / config.melStride + } + /// Get next chunk features (for testing) internal func getNextChunkFeatures() -> (mel: [Float], melLength: Int)? { lock.lock() diff --git a/Tests/FluidAudioTests/Diarizer/LS-EEND/LSEENDIntegrationTests.swift b/Tests/FluidAudioTests/Diarizer/LS-EEND/LSEENDIntegrationTests.swift index f584f5af3..e4c7f2d55 100644 --- a/Tests/FluidAudioTests/Diarizer/LS-EEND/LSEENDIntegrationTests.swift +++ b/Tests/FluidAudioTests/Diarizer/LS-EEND/LSEENDIntegrationTests.swift @@ -4,13 +4,14 @@ import XCTest @testable import FluidAudio +@MainActor final class LSEENDIntegrationTests: XCTestCase { private struct ErrorStats { let maxAbs: Double let meanAbs: Double } - nonisolated(unsafe) private static var cachedEngines: [LSEENDVariant: LSEENDInferenceHelper] = [:] + private static var cachedEngines: [LSEENDVariant: LSEENDInferenceHelper] = [:] func testVariantRegistryResolvesAllExportedArtifacts() async throws { let expectedColumns: [LSEENDVariant: Int] = [ @@ -111,6 +112,54 @@ final class LSEENDIntegrationTests: XCTestCase { assertMatrixClose(snapshot.fullLogits, offline.fullLogits, maxAbs: 1e-5, meanAbs: 1e-6) } + func testStreamingFinalizationUsesExactPaddingForTailFlush() async throws { + let engine = try await makeEngine(variant: .dihard3) + let sampleRate = engine.targetSampleRate + let stableBlockSize = engine.metadata.resolvedHopLength * engine.metadata.resolvedSubsampling + let rawSamples = try DiarizationTestFixtures.fixtureAudio(sampleRate: sampleRate, limitSeconds: 6.0) + let sampleCount = stableBlockSize * 2 + stableBlockSize / 2 + 37 + let samples = Array(rawSamples.prefix(sampleCount)) + let expected = try engine.infer(samples: samples, sampleRate: sampleRate) + + let targetEndFrame = Int(round(Double(samples.count) / Double(sampleRate) * engine.modelFrameHz)) + let expectedPaddingSamples = max(0, targetEndFrame * stableBlockSize - samples.count) + + XCTAssertGreaterThan(expectedPaddingSamples, 0) + XCTAssertEqual((samples.count + expectedPaddingSamples) % stableBlockSize, 0) + + let session = try engine.createSession(inputSampleRate: sampleRate) + _ = try session.pushAudio(samples) + let finalUpdate = try session.finalize() + + XCTAssertNotNil(finalUpdate) + XCTAssertEqual(finalUpdate?.previewLogits.rows, 0) + XCTAssertEqual(finalUpdate?.previewProbabilities.rows, 0) + XCTAssertNil(try session.finalize()) + + let snapshot = session.snapshot() + XCTAssertEqual(snapshot.probabilities.rows, expected.probabilities.rows) + XCTAssertEqual(snapshot.logits.rows, expected.logits.rows) + assertMatrixClose(snapshot.logits, expected.logits, maxAbs: 1e-5, meanAbs: 1e-6) + assertMatrixClose(snapshot.probabilities, expected.probabilities, maxAbs: 1e-5, meanAbs: 1e-6) + assertMatrixClose(snapshot.fullLogits, expected.fullLogits, maxAbs: 1e-5, meanAbs: 1e-6) + assertMatrixClose(snapshot.fullProbabilities, expected.fullProbabilities, maxAbs: 1e-5, meanAbs: 1e-6) + } + + func testEmptyAudioFinalizationProducesNoOutput() async throws { + let engine = try await makeEngine(variant: .dihard3) + let session = try engine.createSession(inputSampleRate: engine.targetSampleRate) + + XCTAssertNil(try session.finalize()) + XCTAssertNil(try session.finalize()) + + let snapshot = session.snapshot() + XCTAssertEqual(snapshot.logits.rows, 0) + XCTAssertEqual(snapshot.probabilities.rows, 0) + XCTAssertEqual(snapshot.fullLogits.rows, 0) + XCTAssertEqual(snapshot.fullProbabilities.rows, 0) + XCTAssertEqual(snapshot.durationSeconds, 0) + } + func testStreamingSimulationMatchesOfflineInferenceAndReportsMonotonicProgress() async throws { let engine = try await makeEngine(variant: .dihard3) let fixtureURL = try DiarizationTestFixtures.fixtureAudioFileURL() @@ -166,10 +215,13 @@ final class LSEENDIntegrationTests: XCTestCase { for chunk in DiarizationTestFixtures.chunk(samples, sizes: [701, 977, 1153]) { let _ = try diarizer.process(samples: chunk) } - let _ = try diarizer.finalizeSession() + let finalChunk = try diarizer.finalizeSession() XCTAssertEqual(diarizer.numFramesProcessed, expected.probabilities.rows) XCTAssertEqual(diarizer.timeline.numFinalizedFrames, expected.probabilities.rows) + XCTAssertEqual(finalChunk?.tentativeFrameCount, 0) + XCTAssertEqual(finalChunk?.tentativePredictions.count, 0) + XCTAssertEqual(diarizer.timeline.numTentativeFrames, 0) assertArrayClose( diarizer.timeline.finalizedPredictions, expected.probabilities.values, maxAbs: 1e-5, meanAbs: 1e-6) @@ -247,6 +299,49 @@ final class LSEENDIntegrationTests: XCTestCase { XCTAssertFalse(diarizer.hasActiveSession) } + func testDiarizerTimelineCountAndDurationPropertiesReflectFrames() throws { + let frameDurationSeconds: Float = 0.25 + let timeline = DiarizerTimeline(config: .default(numSpeakers: 3, frameDurationSeconds: frameDurationSeconds)) + + XCTAssertEqual(timeline.numFinalizedFrames, 0) + XCTAssertEqual(timeline.numTentativeFrames, 0) + XCTAssertEqual(timeline.numFrames, 0) + XCTAssertEqual(timeline.finalizedDuration, 0) + XCTAssertEqual(timeline.tentativeDuration, 0) + XCTAssertEqual(timeline.duration, 0) + + let chunk = DiarizerChunkResult( + startFrame: 5, + finalizedPredictions: [ + 0.10, 0.20, 0.30, + 0.40, 0.50, 0.60, + ], + finalizedFrameCount: 2, + tentativePredictions: [ + 0.70, 0.80, 0.90, + ], + tentativeFrameCount: 1 + ) + + try timeline.addChunk(chunk) + + XCTAssertEqual(timeline.numFinalizedFrames, 2) + XCTAssertEqual(timeline.numTentativeFrames, 1) + XCTAssertEqual(timeline.numFrames, 3) + XCTAssertEqual(timeline.finalizedDuration, 0.5, accuracy: 1e-6) + XCTAssertEqual(timeline.tentativeDuration, 0.25, accuracy: 1e-6) + XCTAssertEqual(timeline.duration, 0.75, accuracy: 1e-6) + + timeline.finalize() + + XCTAssertEqual(timeline.numFinalizedFrames, 3) + XCTAssertEqual(timeline.numTentativeFrames, 0) + XCTAssertEqual(timeline.numFrames, 3) + XCTAssertEqual(timeline.finalizedDuration, 0.75, accuracy: 1e-6) + XCTAssertEqual(timeline.tentativeDuration, 0, accuracy: 1e-6) + XCTAssertEqual(timeline.duration, 0.75, accuracy: 1e-6) + } + private func makeEngine(variant: LSEENDVariant) async throws -> LSEENDInferenceHelper { if let cached = Self.cachedEngines[variant] { return cached diff --git a/Tests/FluidAudioTests/Diarizer/Sortformer/SortformerStreamingIntegrationTests.swift b/Tests/FluidAudioTests/Diarizer/Sortformer/SortformerStreamingIntegrationTests.swift new file mode 100644 index 000000000..76a9ca54b --- /dev/null +++ b/Tests/FluidAudioTests/Diarizer/Sortformer/SortformerStreamingIntegrationTests.swift @@ -0,0 +1,78 @@ +@preconcurrency @testable import FluidAudio +import XCTest + +@MainActor +final class SortformerStreamingIntegrationTests: XCTestCase { + private static var cachedModels: SortformerModels? + + private func loadModelsForTest(config: SortformerConfig) async throws -> SortformerModels { + if let cachedModels = Self.cachedModels { + return cachedModels + } + + let models = try await SortformerModels.loadFromHuggingFace(config: config, computeUnits: .cpuOnly) + Self.cachedModels = models + return models + } + + func testFinalizeSessionMatchesProcessCompleteFrameCount() async throws { + let config = SortformerConfig.default + let models: SortformerModels + do { + models = try await loadModelsForTest(config: config) + } catch { + throw XCTSkip("Sortformer models unavailable in this environment: \(error)") + } + let samples = try DiarizationTestFixtures.fixtureAudio(sampleRate: config.sampleRate, limitSeconds: 4.0) + + let streamingDiarizer = SortformerDiarizer(config: config) + streamingDiarizer.initialize(models: models) + for chunk in DiarizationTestFixtures.chunk(samples, sizes: [4_800, 7_680, 9_600]) { + let _ = try streamingDiarizer.process(samples: chunk) + } + let finalChunk = try streamingDiarizer.finalizeSession() + + let completeDiarizer = SortformerDiarizer(config: config) + completeDiarizer.initialize(models: models) + let expectedTimeline = try completeDiarizer.processComplete(samples) + + XCTAssertLessThanOrEqual( + abs(streamingDiarizer.timeline.numFinalizedFrames - expectedTimeline.numFinalizedFrames), + 1 + ) + XCTAssertEqual(finalChunk?.tentativeFrameCount, 0) + XCTAssertEqual(finalChunk?.tentativePredictions.count, 0) + XCTAssertEqual(streamingDiarizer.timeline.numTentativeFrames, 0) + } + + func testFinalizeSessionFlushesTentativeTailAfterAddAudioOnly() async throws { + let config = SortformerConfig.default + let models: SortformerModels + do { + models = try await loadModelsForTest(config: config) + } catch { + throw XCTSkip("Sortformer models unavailable in this environment: \(error)") + } + let samples = try DiarizationTestFixtures.fixtureAudio(sampleRate: config.sampleRate, limitSeconds: 4.0) + + let bufferedDiarizer = SortformerDiarizer(config: config) + bufferedDiarizer.initialize(models: models) + bufferedDiarizer.addAudio(samples) + let bufferedFinalChunk = try bufferedDiarizer.finalizeSession() + + let streamingDiarizer = SortformerDiarizer(config: config) + streamingDiarizer.initialize(models: models) + for chunk in DiarizationTestFixtures.chunk(samples, sizes: [4_800, 7_680, 9_600]) { + let _ = try streamingDiarizer.process(samples: chunk) + } + let _ = try streamingDiarizer.finalizeSession() + + XCTAssertNotNil(bufferedFinalChunk) + XCTAssertEqual(bufferedFinalChunk?.tentativeFrameCount, 0) + XCTAssertEqual(bufferedDiarizer.timeline.numTentativeFrames, 0) + XCTAssertEqual( + bufferedDiarizer.timeline.numFinalizedFrames, + streamingDiarizer.timeline.numFinalizedFrames + ) + } +} diff --git a/Tests/FluidAudioTests/Diarizer/Sortformer/SortformerTimelineTests.swift b/Tests/FluidAudioTests/Diarizer/Sortformer/SortformerTimelineTests.swift index 62da8b5a4..bb07c5b7c 100644 --- a/Tests/FluidAudioTests/Diarizer/Sortformer/SortformerTimelineTests.swift +++ b/Tests/FluidAudioTests/Diarizer/Sortformer/SortformerTimelineTests.swift @@ -136,6 +136,66 @@ final class SortformerTimelineTests: XCTestCase { XCTAssertTrue(timeline.tentativePredictions.isEmpty) } + func testSegmentConfidenceExcludesPaddingFrames() throws { + let config = DiarizerTimelineConfig( + numSpeakers: 1, + frameDurationSeconds: 0.08, + onsetThreshold: 0.5, + offsetThreshold: 0.5, + onsetPadFrames: 1, + offsetPadFrames: 2, + minFramesOn: 0, + minFramesOff: 0 + ) + let timeline = DiarizerTimeline(config: config) + let predictions: [Float] = [0.0, 0.8, 0.6, 0.0] + + try timeline.addChunk( + DiarizerChunkResult( + startFrame: 0, + finalizedPredictions: predictions, + finalizedFrameCount: predictions.count + ) + ) + + timeline.finalize() + + let segment = try XCTUnwrap(timeline.speakers[0]?.finalizedSegments.first) + XCTAssertEqual(segment.startFrame, 0) + XCTAssertEqual(segment.endFrame, 5) + XCTAssertEqual(segment.confidence, 0.7, accuracy: 1e-6) + } + + func testSegmentConfidenceExcludesBridgedGapFrames() throws { + let config = DiarizerTimelineConfig( + numSpeakers: 1, + frameDurationSeconds: 0.08, + onsetThreshold: 0.5, + offsetThreshold: 0.5, + onsetPadFrames: 0, + offsetPadFrames: 0, + minFramesOn: 0, + minFramesOff: 1 + ) + let timeline = DiarizerTimeline(config: config) + let predictions: [Float] = [0.9, 0.0, 0.7, 0.7, 0.0] + + try timeline.addChunk( + DiarizerChunkResult( + startFrame: 0, + finalizedPredictions: predictions, + finalizedFrameCount: predictions.count + ) + ) + + timeline.finalize() + + let segment = try XCTUnwrap(timeline.speakers[0]?.finalizedSegments.first) + XCTAssertEqual(segment.startFrame, 0) + XCTAssertEqual(segment.endFrame, 4) + XCTAssertEqual(segment.confidence, (0.9 + 0.7 + 0.7) / 3.0, accuracy: 1e-6) + } + // MARK: - Probability Access func testProbabilityAccess() throws { diff --git a/Tests/FluidAudioTests/Diarizer/SpeakerEnrollmentTests.swift b/Tests/FluidAudioTests/Diarizer/SpeakerEnrollmentTests.swift index 5f3634ef1..68b359a72 100644 --- a/Tests/FluidAudioTests/Diarizer/SpeakerEnrollmentTests.swift +++ b/Tests/FluidAudioTests/Diarizer/SpeakerEnrollmentTests.swift @@ -212,6 +212,50 @@ final class SpeakerEnrollmentTests: XCTestCase { XCTAssertEqual(namedSpeakerIndices(in: diarizer.timeline), [speaker?.index].compactMap { $0 }) } + func testSortformerEnrollmentClearsDiscardedSampleCountBeforeFinalize() async throws { + XCTExpectFailure("Download might fail in CI environment", strict: false) + + let config = SortformerConfig.default + let models = try await loadSortformerModelsForTest(config: config) + + let enrollmentAudio = try DiarizationTestFixtures.fixtureAudio( + sampleRate: config.sampleRate, startSeconds: 0.0, durationSeconds: 5.0) + let discardedAudio = try DiarizationTestFixtures.fixtureAudio( + sampleRate: config.sampleRate, startSeconds: 5.0, durationSeconds: 1.5) + let liveAudio = try DiarizationTestFixtures.fixtureAudio( + sampleRate: config.sampleRate, startSeconds: 6.5, durationSeconds: 3.0) + + let dirtyDiarizer = SortformerDiarizer(config: config) + dirtyDiarizer.initialize(models: models) + dirtyDiarizer.addAudio(discardedAudio) + let enrolledSpeaker = try dirtyDiarizer.enrollSpeaker(withAudio: enrollmentAudio, named: "Alice") + try XCTSkipIf( + enrolledSpeaker == nil, "Fixture did not produce a confident Sortformer speaker segment on this host.") + + for chunk in DiarizationTestFixtures.chunk(liveAudio, sizes: [4_800, 7_680, 9_600]) { + let _ = try dirtyDiarizer.process(samples: chunk) + } + let _ = try dirtyDiarizer.finalizeSession() + + let cleanDiarizer = SortformerDiarizer(config: config) + cleanDiarizer.initialize(models: models) + let cleanSpeaker = try cleanDiarizer.enrollSpeaker(withAudio: enrollmentAudio, named: "Alice") + try XCTSkipIf( + cleanSpeaker == nil, "Fixture did not produce a confident Sortformer speaker segment on this host.") + + for chunk in DiarizationTestFixtures.chunk(liveAudio, sizes: [4_800, 7_680, 9_600]) { + let _ = try cleanDiarizer.process(samples: chunk) + } + let _ = try cleanDiarizer.finalizeSession() + + XCTAssertLessThanOrEqual( + abs(dirtyDiarizer.timeline.numFinalizedFrames - cleanDiarizer.timeline.numFinalizedFrames), + 1 + ) + XCTAssertEqual(dirtyDiarizer.timeline.numTentativeFrames, 0) + XCTAssertEqual(cleanDiarizer.timeline.numTentativeFrames, 0) + } + func testSortformerMultipleEnrollmentsRetainNamedSpeakersAndState() async throws { XCTExpectFailure("Download might fail in CI environment", strict: false)