From 59a63e2d952be77e0da16d3c6e07f1afeae22c99 Mon Sep 17 00:00:00 2001 From: Benjamin Lee Date: Tue, 24 Mar 2026 19:43:43 -0700 Subject: [PATCH 01/12] Streamline diarizer finalization changes --- Documentation/API.md | 2 +- Documentation/Diarization/DiarizerTimeline.md | 2 +- Documentation/Diarization/LS-EEND.md | 6 +- Documentation/Diarization/Sortformer.md | 8 ++ .../Diarizer/LS-EEND/LSEENDDiarizerAPI.swift | 36 +++--- .../LS-EEND/LSEENDModelInference.swift | 48 +++++--- .../SortformerDiarizerPipeline.swift | 109 ++++++++++++++++-- .../SortformerStreamingIntegrationTests.swift | 45 ++++++++ .../Diarizer/SpeakerEnrollmentTests.swift | 44 +++++++ 9 files changed, 256 insertions(+), 44 deletions(-) create mode 100644 Tests/FluidAudioTests/Diarizer/Sortformer/SortformerStreamingIntegrationTests.swift diff --git a/Documentation/API.md b/Documentation/API.md index b8a46e237..95ab43c57 100644 --- a/Documentation/API.md +++ b/Documentation/API.md @@ -83,7 +83,7 @@ Use `OfflineDiarizerManager` when you need offline DER parity or want to run the **Speaker Enrollment:** `enrollSpeaker(withAudio:sourceSampleRate:named:...)` feeds known-speaker audio before streaming to label a slot. -**Lifecycle:** `reset()` clears streaming state but keeps the model loaded. `cleanup()` releases everything. +**Lifecycle:** `finalizeSession()` flushes trailing context so the last true frame becomes finalized. `reset()` clears streaming state but keeps the model loaded. `cleanup()` releases everything. --- diff --git a/Documentation/Diarization/DiarizerTimeline.md b/Documentation/Diarization/DiarizerTimeline.md index 54933b7fa..c80a639ba 100644 --- a/Documentation/Diarization/DiarizerTimeline.md +++ b/Documentation/Diarization/DiarizerTimeline.md @@ -157,4 +157,4 @@ When `timeline.addChunk(_:)` is called internally by the diarizer: 3. It iterates over all `DiarizerSpeaker` tracks, evaluating the boundaries (using `onsetThreshold` and `offsetThreshold`) to grow existing segments or spawn new ones. 4. Tentative segments are cleared and rebuilt from the trailing `tentativePredictions` array during every streaming tick. -When the stream naturally finishes, the `Diarizer` automatically invokes `timeline.finalize()`, which flushes any remaining tentative segments up to finalized status and applies the `minFramesOn` deletion rules. +When the stream naturally finishes, call `Diarizer.finalizeSession()`. The diarizer flushes trailing context first, then invokes `timeline.finalize()`, which promotes any remaining tentative segments to finalized status and applies the `minFramesOn` deletion rules. diff --git a/Documentation/Diarization/LS-EEND.md b/Documentation/Diarization/LS-EEND.md index 3e357b326..b8f740e46 100644 --- a/Documentation/Diarization/LS-EEND.md +++ b/Documentation/Diarization/LS-EEND.md @@ -226,11 +226,14 @@ if let update = try diarizer.process() { // Convenience: add + process in one call if let update = try diarizer.process(samples: audioChunk) { ... } -// Flush remaining frames at end of stream +// Flush remaining committed + preview frames at end of stream try diarizer.finalizeSession() let finalTimeline = diarizer.timeline ``` +Notes: +- `finalizeSession()` pads the tail with silence when needed so the last true frame is emitted as finalized output before the timeline is finalized. + ### Speaker Enrollment Use speaker enrollment to warm LS-EEND with a known speaker before the live stream starts. Enrollment keeps the active streaming session, resets the visible timeline back to frame 0, and preserves the speaker name inside the `DiarizerTimeline`. @@ -288,6 +291,7 @@ Real-world integration testing with 4-speaker audio reveals specific enrollment ### Lifecycle ```swift +try diarizer.finalizeSession() // Flush trailing context before reading final output diarizer.reset() // Reset streaming state for a new audio stream (keeps model loaded) diarizer.cleanup() // Release all resources including the loaded model ``` diff --git a/Documentation/Diarization/Sortformer.md b/Documentation/Diarization/Sortformer.md index 0c9812411..14733d3f1 100644 --- a/Documentation/Diarization/Sortformer.md +++ b/Documentation/Diarization/Sortformer.md @@ -368,6 +368,10 @@ public struct SortformerSegment { │ └─→ timeline.addChunk(result) │ │ └─→ Update segments per speaker │ │ │ +│ 3. finalizeSession() │ +│ └─→ pad trailing silence until last true frame is emitted │ +│ └─→ timeline.finalize() │ +│ │ └────────────────────────────────────────────────────────────────┘ ``` @@ -444,6 +448,8 @@ audioEngine.installTap { buffer in updateSpeakerDisplay(diarizer.timeline) } } + +try diarizer.finalizeSession() ``` ### Batch Processing @@ -466,6 +472,8 @@ for (index, speaker) in timeline.speakers { } ``` +`finalizeSession()` is only needed for streaming mode. It pads enough trailing silence to flush Sortformer's right-context preview frames, then finalizes the timeline so `numTentativeFrames == 0`. + ### Speaker Enrollment Use speaker enrollment to warm Sortformer with known speakers before live audio starts. Enrollment preserves the speaker cache / FIFO state, resets the visible timeline, and keeps the speaker name in the `DiarizerTimeline`. diff --git a/Sources/FluidAudio/Diarizer/LS-EEND/LSEENDDiarizerAPI.swift b/Sources/FluidAudio/Diarizer/LS-EEND/LSEENDDiarizerAPI.swift index 0a9aaec94..0dd3986a9 100644 --- a/Sources/FluidAudio/Diarizer/LS-EEND/LSEENDDiarizerAPI.swift +++ b/Sources/FluidAudio/Diarizer/LS-EEND/LSEENDDiarizerAPI.swift @@ -678,6 +678,8 @@ public final class LSEENDDiarizer: Diarizer { defer { lock.unlock() } guard let engine = _engine, let session = _session else { return nil } + let numSpeakers = engine.metadata.realOutputDim + var lastResult: DiarizerChunkResult? // Flush pending audio first — clear unconditionally so failed audio isn't retained. // Using defer + direct pass avoids a CoW copy. @@ -685,39 +687,35 @@ public final class LSEENDDiarizer: Diarizer { defer { pendingAudio.removeAll(keepingCapacity: true) } let pushedUpdate = try session.pushAudio(pendingAudio) if let update = pushedUpdate { - let numSpeakers = engine.metadata.realOutputDim let flushedResult = DiarizerChunkResult( startFrame: _numFramesProcessed, finalizedPredictions: flattenRowMajor(update.probabilities, numSpeakers: numSpeakers), finalizedFrameCount: update.probabilities.rows, - tentativePredictions: [], - tentativeFrameCount: 0 + tentativePredictions: flattenRowMajor(update.previewProbabilities, numSpeakers: numSpeakers), + tentativeFrameCount: update.previewProbabilities.rows ) _numFramesProcessed += flushedResult.finalizedFrameCount try _timeline.addChunk(flushedResult) + lastResult = flushedResult } } - guard let finalUpdate = try session.finalize() else { - _session = nil - _timeline.finalize() - return nil + if let finalUpdate = try session.finalize() { + let finalResult = DiarizerChunkResult( + startFrame: _numFramesProcessed, + finalizedPredictions: flattenRowMajor(finalUpdate.probabilities, numSpeakers: numSpeakers), + finalizedFrameCount: finalUpdate.probabilities.rows, + tentativePredictions: [], + tentativeFrameCount: 0 + ) + _numFramesProcessed += finalResult.finalizedFrameCount + try _timeline.addChunk(finalResult) + lastResult = finalResult } - - let numSpeakers = engine.metadata.realOutputDim - let result = DiarizerChunkResult( - startFrame: _numFramesProcessed, - finalizedPredictions: flattenRowMajor(finalUpdate.probabilities, numSpeakers: numSpeakers), - finalizedFrameCount: finalUpdate.probabilities.rows, - tentativePredictions: [], - tentativeFrameCount: 0 - ) - _numFramesProcessed += result.finalizedFrameCount - try _timeline.addChunk(result) _timeline.finalize() _session = nil - return result + return lastResult } // MARK: - Private diff --git a/Sources/FluidAudio/Diarizer/LS-EEND/LSEENDModelInference.swift b/Sources/FluidAudio/Diarizer/LS-EEND/LSEENDModelInference.swift index 69afd6a36..e1d593b2e 100644 --- a/Sources/FluidAudio/Diarizer/LS-EEND/LSEENDModelInference.swift +++ b/Sources/FluidAudio/Diarizer/LS-EEND/LSEENDModelInference.swift @@ -208,17 +208,11 @@ public final class LSEENDInferenceHelper { /// - Returns: Complete inference result with logits and probabilities. public func infer(samples: [Float], sampleRate: Int) throws -> LSEENDInferenceResult { let normalizedAudio = try resampleIfNeeded(samples: samples, sampleRate: sampleRate) - let features = try offlineFeatureExtractor.extractFeatures(audio: normalizedAudio) let session = try createSession(inputSampleRate: targetSampleRate) - session.totalInputSamples = normalizedAudio.count - let committed = try session.ingestFeatures(features) - let pending = session.totalFeatureFrames - session.emittedFrames - let tail = - try pending > 0 - ? session.flushTail(from: session.state, pendingFrames: pending) : .empty(columns: decodeMaxSpeakers) - let fullLogits = committed.appendingRows(tail) - session.fullLogitChunks = fullLogits.isEmpty ? [] : [fullLogits] - session.emittedFrames = fullLogits.rows + if !normalizedAudio.isEmpty { + _ = try session.pushAudio(normalizedAudio) + } + _ = try session.finalize() return session.snapshot() } @@ -518,13 +512,41 @@ public final class LSEENDStreamingSession { guard !finalized else { return nil } - let features = try featureExtractor.finalize() - let committed = try ingestFeatures(features) + + var committedFullLogits = LSEENDMatrix.empty(columns: engine.decodeMaxSpeakers) + let targetEndFrame = max( + totalFeatureFrames, + Int(round(Double(totalInputSamples) / Double(max(inputSampleRate, 1)) * engine.modelFrameHz)) + ) + let paddingSamples = max( + 1, + Int(ceil(max(engine.streamingLatencySeconds, 1.0 / engine.modelFrameHz) * Double(inputSampleRate))) + ) + + var stalledPasses = 0 + while emittedFrames < targetEndFrame && stalledPasses < 4 { + let features = try featureExtractor.pushAudio([Float](repeating: 0, count: paddingSamples)) + let committed = try ingestFeatures(features) + if committed.rows > 0 { + committedFullLogits = committedFullLogits.appendingRows(committed) + stalledPasses = 0 + } else { + stalledPasses += 1 + } + } + + let finalFeatures = try featureExtractor.finalize() + let finalCommitted = try ingestFeatures(finalFeatures) + if finalCommitted.rows > 0 { + committedFullLogits = committedFullLogits.appendingRows(finalCommitted) + } + let pending = totalFeatureFrames - emittedFrames let tail = try pending > 0 ? flushTail(from: state, pendingFrames: pending) : .empty(columns: engine.decodeMaxSpeakers) + emittedFrames += tail.rows finalized = true - return try buildUpdate(committedFullLogits: committed.appendingRows(tail), includePreview: false) + return try buildUpdate(committedFullLogits: committedFullLogits.appendingRows(tail), includePreview: false) } /// Assembles the full inference result from all committed frames emitted so far. diff --git a/Sources/FluidAudio/Diarizer/Sortformer/SortformerDiarizerPipeline.swift b/Sources/FluidAudio/Diarizer/Sortformer/SortformerDiarizerPipeline.swift index 2c499304b..340dc818b 100644 --- a/Sources/FluidAudio/Diarizer/Sortformer/SortformerDiarizerPipeline.swift +++ b/Sources/FluidAudio/Diarizer/Sortformer/SortformerDiarizerPipeline.swift @@ -44,6 +44,7 @@ public final class SortformerDiarizer: Diarizer { return _numFramesProcessed } private var _numFramesProcessed: Int = 0 + private var _realSamplesReceived: Int = 0 /// Configuration public let config: SortformerConfig @@ -151,6 +152,7 @@ public final class SortformerDiarizer: Diarizer { startFeat = 0 diarizerChunkIndex = 0 _numFramesProcessed = 0 + _realSamplesReceived = 0 _timeline.reset(keepingSpeakers: keepingSpeakers) featureBuffer.reserveCapacity((config.chunkMelFrames + config.coreFrames) * config.melFeatures) @@ -253,6 +255,7 @@ public final class SortformerDiarizer: Diarizer { diarizerChunkIndex = 0 audioBuffer.removeAll(keepingCapacity: true) featureBuffer.removeAll(keepingCapacity: true) + _realSamplesReceived = 0 audioBuffer.append(contentsOf: normalized) preprocessAudioToFeaturesLocked() @@ -285,6 +288,7 @@ public final class SortformerDiarizer: Diarizer { lastAudioSample = 0 audioBuffer.removeAll(keepingCapacity: true) featureBuffer.removeAll(keepingCapacity: true) + _realSamplesReceived = 0 return nil } logger.warning( @@ -306,6 +310,7 @@ public final class SortformerDiarizer: Diarizer { lastAudioSample = 0 audioBuffer.removeAll(keepingCapacity: true) featureBuffer.removeAll(keepingCapacity: true) + _realSamplesReceived = 0 logger.info( "Enrolled speaker \(description) with \(normalized.count) samples " @@ -327,6 +332,7 @@ public final class SortformerDiarizer: Diarizer { public func addAudio(_ samples: [Float]) { lock.withLock { audioBuffer.append(contentsOf: samples) + _realSamplesReceived += samples.count preprocessAudioToFeaturesLocked() } } @@ -343,6 +349,7 @@ public final class SortformerDiarizer: Diarizer { let normalized = try normalizeSamples(samples, sourceSampleRate: sourceSampleRate) lock.withLock { audioBuffer.append(contentsOf: normalized) + _realSamplesReceived += normalized.count preprocessAudioToFeaturesLocked() } } @@ -360,8 +367,10 @@ public final class SortformerDiarizer: Diarizer { if let sourceSampleRate, sourceSampleRate != Double(config.sampleRate) { let normalized = try normalizeSamples(Array(samples), sourceSampleRate: sourceSampleRate) audioBuffer.append(contentsOf: normalized) + _realSamplesReceived += normalized.count } else { audioBuffer.append(contentsOf: samples) + _realSamplesReceived += samples.count } preprocessAudioToFeaturesLocked() } @@ -378,10 +387,6 @@ public final class SortformerDiarizer: Diarizer { } } - /// Process a chunk of audio in one call. - /// - /// Convenience method that combines `addAudio()` and `process()`. - /// /// Process a chunk of audio in one call. /// /// Convenience method that combines `addAudio()` and `process()`. @@ -399,8 +404,10 @@ public final class SortformerDiarizer: Diarizer { if let sourceSampleRate, sourceSampleRate != Double(config.sampleRate) { let normalized = try normalizeSamples(Array(samples), sourceSampleRate: sourceSampleRate) audioBuffer.append(contentsOf: normalized) + _realSamplesReceived += normalized.count } else { audioBuffer.append(contentsOf: samples) + _realSamplesReceived += samples.count } preprocessAudioToFeaturesLocked() return try processLocked() @@ -409,6 +416,16 @@ public final class SortformerDiarizer: Diarizer { /// Internal process - caller must hold lock private func processLocked(updateTimeline: Bool = true) throws -> DiarizerTimelineUpdate? { + guard let chunk = try makeStreamingChunkLocked() else { + return nil + } + + _numFramesProcessed += chunk.finalizedFrameCount + guard updateTimeline else { return nil } + return try _timeline.addChunk(chunk) + } + + private func makeStreamingChunkLocked() throws -> DiarizerChunkResult? { guard let models = _models else { throw SortformerError.notInitialized } @@ -457,22 +474,96 @@ public final class SortformerDiarizer: Diarizer { } // Return new results if any - if newPredictions.count > 0 && updateTimeline { - let chunk = DiarizerChunkResult( + if newPredictions.count > 0 { + return DiarizerChunkResult( startFrame: _numFramesProcessed, finalizedPredictions: newPredictions, finalizedFrameCount: newFrameCount, tentativePredictions: newTentativePredictions, tentativeFrameCount: newTentativeFrameCount ) - - _numFramesProcessed += newFrameCount - return try _timeline.addChunk(chunk) } return nil } + /// Finalize the current streaming session. + /// + /// Pads the tail with silence until the last true frame has been emitted as + /// finalized output, then finalizes the timeline. + /// + /// - Returns: The last finalized chunk emitted during finalization, if any. + @discardableResult + public func finalizeSession() throws -> DiarizerChunkResult? { + try lock.withLock { + guard _models != nil else { + throw SortformerError.notInitialized + } + + var lastResult: DiarizerChunkResult? + if let chunk = try makeStreamingChunkLocked() { + _numFramesProcessed = chunk.startFrame + chunk.finalizedFrameCount + try _timeline.addChunk(chunk) + lastResult = chunk + } + + let targetEndFrame = max( + _numFramesProcessed + _timeline.numTentativeFrames, + Int( + round( + Double(_realSamplesReceived) * (1.0 / Double(config.frameDurationSeconds)) + / Double(config.sampleRate))) + ) + let paddingSamples = max( + config.melWindow, + config.chunkRightContext * config.subsamplingFactor * config.melStride + ) + + var stalledPasses = 0 + while _numFramesProcessed < targetEndFrame && stalledPasses < 4 { + audioBuffer.append(contentsOf: [Float](repeating: 0, count: paddingSamples)) + preprocessAudioToFeaturesLocked() + + guard let chunk = try makeStreamingChunkLocked(), chunk.finalizedFrameCount > 0 else { + stalledPasses += 1 + continue + } + + let remainingFrames = targetEndFrame - _numFramesProcessed + let finalizedFrameCount = min(remainingFrames, chunk.finalizedFrameCount) + guard finalizedFrameCount > 0 else { + stalledPasses += 1 + continue + } + + let finalizedPredictions = Array( + chunk.finalizedPredictions.prefix(finalizedFrameCount * config.numSpeakers) + ) + let finalizedResult = DiarizerChunkResult( + startFrame: _numFramesProcessed, + finalizedPredictions: finalizedPredictions, + finalizedFrameCount: finalizedFrameCount, + tentativePredictions: [], + tentativeFrameCount: 0 + ) + _numFramesProcessed += finalizedFrameCount + try _timeline.addChunk(finalizedResult) + lastResult = finalizedResult + stalledPasses = 0 + } + + if _numFramesProcessed < targetEndFrame { + logger.warning( + "Sortformer finalize stalled before last true frame was confirmed " + + "(\(_numFramesProcessed) / \(targetEndFrame) frames)" + ) + } + + _timeline.finalize() + return lastResult + } + } + // MARK: - Complete File Processing /// Progress callback type: (processedSamples, totalSamples, chunksProcessed) diff --git a/Tests/FluidAudioTests/Diarizer/Sortformer/SortformerStreamingIntegrationTests.swift b/Tests/FluidAudioTests/Diarizer/Sortformer/SortformerStreamingIntegrationTests.swift new file mode 100644 index 000000000..0230a9d7f --- /dev/null +++ b/Tests/FluidAudioTests/Diarizer/Sortformer/SortformerStreamingIntegrationTests.swift @@ -0,0 +1,45 @@ +@preconcurrency @testable import FluidAudio +import XCTest + +@MainActor +final class SortformerStreamingIntegrationTests: XCTestCase { + private static var cachedModels: SortformerModels? + + private func loadModelsForTest(config: SortformerConfig) async throws -> SortformerModels { + if let cachedModels = Self.cachedModels { + return cachedModels + } + + let models = try await SortformerModels.loadFromHuggingFace(config: config, computeUnits: .cpuOnly) + Self.cachedModels = models + return models + } + + func testFinalizeSessionMatchesProcessCompleteFrameCount() async throws { + let config = SortformerConfig.default + let models: SortformerModels + do { + models = try await loadModelsForTest(config: config) + } catch { + throw XCTSkip("Sortformer models unavailable in this environment: \(error)") + } + let samples = try DiarizationTestFixtures.fixtureAudio(sampleRate: config.sampleRate, limitSeconds: 4.0) + + let streamingDiarizer = SortformerDiarizer(config: config) + streamingDiarizer.initialize(models: models) + for chunk in DiarizationTestFixtures.chunk(samples, sizes: [4_800, 7_680, 9_600]) { + let _ = try streamingDiarizer.process(samples: chunk) + } + let _ = try streamingDiarizer.finalizeSession() + + let completeDiarizer = SortformerDiarizer(config: config) + completeDiarizer.initialize(models: models) + let expectedTimeline = try completeDiarizer.processComplete(samples) + + XCTAssertLessThanOrEqual( + abs(streamingDiarizer.timeline.numFinalizedFrames - expectedTimeline.numFinalizedFrames), + 1 + ) + XCTAssertEqual(streamingDiarizer.timeline.numTentativeFrames, 0) + } +} diff --git a/Tests/FluidAudioTests/Diarizer/SpeakerEnrollmentTests.swift b/Tests/FluidAudioTests/Diarizer/SpeakerEnrollmentTests.swift index 5f3634ef1..68b359a72 100644 --- a/Tests/FluidAudioTests/Diarizer/SpeakerEnrollmentTests.swift +++ b/Tests/FluidAudioTests/Diarizer/SpeakerEnrollmentTests.swift @@ -212,6 +212,50 @@ final class SpeakerEnrollmentTests: XCTestCase { XCTAssertEqual(namedSpeakerIndices(in: diarizer.timeline), [speaker?.index].compactMap { $0 }) } + func testSortformerEnrollmentClearsDiscardedSampleCountBeforeFinalize() async throws { + XCTExpectFailure("Download might fail in CI environment", strict: false) + + let config = SortformerConfig.default + let models = try await loadSortformerModelsForTest(config: config) + + let enrollmentAudio = try DiarizationTestFixtures.fixtureAudio( + sampleRate: config.sampleRate, startSeconds: 0.0, durationSeconds: 5.0) + let discardedAudio = try DiarizationTestFixtures.fixtureAudio( + sampleRate: config.sampleRate, startSeconds: 5.0, durationSeconds: 1.5) + let liveAudio = try DiarizationTestFixtures.fixtureAudio( + sampleRate: config.sampleRate, startSeconds: 6.5, durationSeconds: 3.0) + + let dirtyDiarizer = SortformerDiarizer(config: config) + dirtyDiarizer.initialize(models: models) + dirtyDiarizer.addAudio(discardedAudio) + let enrolledSpeaker = try dirtyDiarizer.enrollSpeaker(withAudio: enrollmentAudio, named: "Alice") + try XCTSkipIf( + enrolledSpeaker == nil, "Fixture did not produce a confident Sortformer speaker segment on this host.") + + for chunk in DiarizationTestFixtures.chunk(liveAudio, sizes: [4_800, 7_680, 9_600]) { + let _ = try dirtyDiarizer.process(samples: chunk) + } + let _ = try dirtyDiarizer.finalizeSession() + + let cleanDiarizer = SortformerDiarizer(config: config) + cleanDiarizer.initialize(models: models) + let cleanSpeaker = try cleanDiarizer.enrollSpeaker(withAudio: enrollmentAudio, named: "Alice") + try XCTSkipIf( + cleanSpeaker == nil, "Fixture did not produce a confident Sortformer speaker segment on this host.") + + for chunk in DiarizationTestFixtures.chunk(liveAudio, sizes: [4_800, 7_680, 9_600]) { + let _ = try cleanDiarizer.process(samples: chunk) + } + let _ = try cleanDiarizer.finalizeSession() + + XCTAssertLessThanOrEqual( + abs(dirtyDiarizer.timeline.numFinalizedFrames - cleanDiarizer.timeline.numFinalizedFrames), + 1 + ) + XCTAssertEqual(dirtyDiarizer.timeline.numTentativeFrames, 0) + XCTAssertEqual(cleanDiarizer.timeline.numTentativeFrames, 0) + } + func testSortformerMultipleEnrollmentsRetainNamedSpeakersAndState() async throws { XCTExpectFailure("Download might fail in CI environment", strict: false) From 4912f91b4a89801e9e9caee64f9d394d2777b995 Mon Sep 17 00:00:00 2001 From: Benjamin Lee Date: Tue, 24 Mar 2026 19:46:42 -0700 Subject: [PATCH 02/12] Exclude tentative frames during finalization --- .../Diarizer/LS-EEND/LSEENDDiarizerAPI.swift | 4 ++-- .../Sortformer/SortformerDiarizerPipeline.swift | 13 ++++++++++--- .../Diarizer/LS-EEND/LSEENDIntegrationTests.swift | 5 ++++- .../SortformerStreamingIntegrationTests.swift | 4 +++- 4 files changed, 19 insertions(+), 7 deletions(-) diff --git a/Sources/FluidAudio/Diarizer/LS-EEND/LSEENDDiarizerAPI.swift b/Sources/FluidAudio/Diarizer/LS-EEND/LSEENDDiarizerAPI.swift index 0dd3986a9..a87470514 100644 --- a/Sources/FluidAudio/Diarizer/LS-EEND/LSEENDDiarizerAPI.swift +++ b/Sources/FluidAudio/Diarizer/LS-EEND/LSEENDDiarizerAPI.swift @@ -691,8 +691,8 @@ public final class LSEENDDiarizer: Diarizer { startFrame: _numFramesProcessed, finalizedPredictions: flattenRowMajor(update.probabilities, numSpeakers: numSpeakers), finalizedFrameCount: update.probabilities.rows, - tentativePredictions: flattenRowMajor(update.previewProbabilities, numSpeakers: numSpeakers), - tentativeFrameCount: update.previewProbabilities.rows + tentativePredictions: [], + tentativeFrameCount: 0 ) _numFramesProcessed += flushedResult.finalizedFrameCount try _timeline.addChunk(flushedResult) diff --git a/Sources/FluidAudio/Diarizer/Sortformer/SortformerDiarizerPipeline.swift b/Sources/FluidAudio/Diarizer/Sortformer/SortformerDiarizerPipeline.swift index 340dc818b..7dc03f43e 100644 --- a/Sources/FluidAudio/Diarizer/Sortformer/SortformerDiarizerPipeline.swift +++ b/Sources/FluidAudio/Diarizer/Sortformer/SortformerDiarizerPipeline.swift @@ -502,9 +502,16 @@ public final class SortformerDiarizer: Diarizer { var lastResult: DiarizerChunkResult? if let chunk = try makeStreamingChunkLocked() { - _numFramesProcessed = chunk.startFrame + chunk.finalizedFrameCount - try _timeline.addChunk(chunk) - lastResult = chunk + let finalizedChunk = DiarizerChunkResult( + startFrame: chunk.startFrame, + finalizedPredictions: chunk.finalizedPredictions, + finalizedFrameCount: chunk.finalizedFrameCount, + tentativePredictions: [], + tentativeFrameCount: 0 + ) + _numFramesProcessed = finalizedChunk.startFrame + finalizedChunk.finalizedFrameCount + try _timeline.addChunk(finalizedChunk) + lastResult = finalizedChunk } let targetEndFrame = max( diff --git a/Tests/FluidAudioTests/Diarizer/LS-EEND/LSEENDIntegrationTests.swift b/Tests/FluidAudioTests/Diarizer/LS-EEND/LSEENDIntegrationTests.swift index f584f5af3..b7aa15d24 100644 --- a/Tests/FluidAudioTests/Diarizer/LS-EEND/LSEENDIntegrationTests.swift +++ b/Tests/FluidAudioTests/Diarizer/LS-EEND/LSEENDIntegrationTests.swift @@ -166,10 +166,13 @@ final class LSEENDIntegrationTests: XCTestCase { for chunk in DiarizationTestFixtures.chunk(samples, sizes: [701, 977, 1153]) { let _ = try diarizer.process(samples: chunk) } - let _ = try diarizer.finalizeSession() + let finalChunk = try diarizer.finalizeSession() XCTAssertEqual(diarizer.numFramesProcessed, expected.probabilities.rows) XCTAssertEqual(diarizer.timeline.numFinalizedFrames, expected.probabilities.rows) + XCTAssertEqual(finalChunk?.tentativeFrameCount, 0) + XCTAssertEqual(finalChunk?.tentativePredictions.count, 0) + XCTAssertEqual(diarizer.timeline.numTentativeFrames, 0) assertArrayClose( diarizer.timeline.finalizedPredictions, expected.probabilities.values, maxAbs: 1e-5, meanAbs: 1e-6) diff --git a/Tests/FluidAudioTests/Diarizer/Sortformer/SortformerStreamingIntegrationTests.swift b/Tests/FluidAudioTests/Diarizer/Sortformer/SortformerStreamingIntegrationTests.swift index 0230a9d7f..4e2a87a22 100644 --- a/Tests/FluidAudioTests/Diarizer/Sortformer/SortformerStreamingIntegrationTests.swift +++ b/Tests/FluidAudioTests/Diarizer/Sortformer/SortformerStreamingIntegrationTests.swift @@ -30,7 +30,7 @@ final class SortformerStreamingIntegrationTests: XCTestCase { for chunk in DiarizationTestFixtures.chunk(samples, sizes: [4_800, 7_680, 9_600]) { let _ = try streamingDiarizer.process(samples: chunk) } - let _ = try streamingDiarizer.finalizeSession() + let finalChunk = try streamingDiarizer.finalizeSession() let completeDiarizer = SortformerDiarizer(config: config) completeDiarizer.initialize(models: models) @@ -40,6 +40,8 @@ final class SortformerStreamingIntegrationTests: XCTestCase { abs(streamingDiarizer.timeline.numFinalizedFrames - expectedTimeline.numFinalizedFrames), 1 ) + XCTAssertEqual(finalChunk?.tentativeFrameCount, 0) + XCTAssertEqual(finalChunk?.tentativePredictions.count, 0) XCTAssertEqual(streamingDiarizer.timeline.numTentativeFrames, 0) } } From c79548c7d97157bacbc11e996bba729d4b0ec087 Mon Sep 17 00:00:00 2001 From: Benjamin Lee Date: Tue, 24 Mar 2026 20:07:58 -0700 Subject: [PATCH 03/12] Use exact diarizer finalization padding --- .../LS-EEND/LSEENDModelInference.swift | 29 ++-- .../SortformerDiarizerPipeline.swift | 139 ++++++++++-------- 2 files changed, 89 insertions(+), 79 deletions(-) diff --git a/Sources/FluidAudio/Diarizer/LS-EEND/LSEENDModelInference.swift b/Sources/FluidAudio/Diarizer/LS-EEND/LSEENDModelInference.swift index e1d593b2e..495fbb4e6 100644 --- a/Sources/FluidAudio/Diarizer/LS-EEND/LSEENDModelInference.swift +++ b/Sources/FluidAudio/Diarizer/LS-EEND/LSEENDModelInference.swift @@ -514,24 +514,14 @@ public final class LSEENDStreamingSession { } var committedFullLogits = LSEENDMatrix.empty(columns: engine.decodeMaxSpeakers) - let targetEndFrame = max( - totalFeatureFrames, - Int(round(Double(totalInputSamples) / Double(max(inputSampleRate, 1)) * engine.modelFrameHz)) - ) - let paddingSamples = max( - 1, - Int(ceil(max(engine.streamingLatencySeconds, 1.0 / engine.modelFrameHz) * Double(inputSampleRate))) - ) - - var stalledPasses = 0 - while emittedFrames < targetEndFrame && stalledPasses < 4 { - let features = try featureExtractor.pushAudio([Float](repeating: 0, count: paddingSamples)) + let targetEndFrame = Int( + round(Double(totalInputSamples) / Double(max(inputSampleRate, 1)) * engine.modelFrameHz)) + let exactPaddingSamples = exactFinalizationPaddingSamples(targetEndFrame: targetEndFrame) + if exactPaddingSamples > 0 { + let features = try featureExtractor.pushAudio([Float](repeating: 0, count: exactPaddingSamples)) let committed = try ingestFeatures(features) if committed.rows > 0 { committedFullLogits = committedFullLogits.appendingRows(committed) - stalledPasses = 0 - } else { - stalledPasses += 1 } } @@ -549,6 +539,15 @@ public final class LSEENDStreamingSession { return try buildUpdate(committedFullLogits: committedFullLogits.appendingRows(tail), includePreview: false) } + private func exactFinalizationPaddingSamples(targetEndFrame: Int) -> Int { + guard targetEndFrame > 0 else { + return 0 + } + let stableBlockSize = engine.metadata.resolvedHopLength * engine.metadata.resolvedSubsampling + let requiredTotalSamples = targetEndFrame * stableBlockSize + return max(0, requiredTotalSamples - totalInputSamples) + } + /// Assembles the full inference result from all committed frames emitted so far. /// /// Can be called at any time (before or after finalization) to get a complete diff --git a/Sources/FluidAudio/Diarizer/Sortformer/SortformerDiarizerPipeline.swift b/Sources/FluidAudio/Diarizer/Sortformer/SortformerDiarizerPipeline.swift index 7dc03f43e..cdd343203 100644 --- a/Sources/FluidAudio/Diarizer/Sortformer/SortformerDiarizerPipeline.swift +++ b/Sources/FluidAudio/Diarizer/Sortformer/SortformerDiarizerPipeline.swift @@ -495,7 +495,7 @@ public final class SortformerDiarizer: Diarizer { /// - Returns: The last finalized chunk emitted during finalization, if any. @discardableResult public func finalizeSession() throws -> DiarizerChunkResult? { - try lock.withLock { + return try lock.withLock { guard _models != nil else { throw SortformerError.notInitialized } @@ -514,33 +514,30 @@ public final class SortformerDiarizer: Diarizer { lastResult = finalizedChunk } - let targetEndFrame = max( - _numFramesProcessed + _timeline.numTentativeFrames, - Int( - round( - Double(_realSamplesReceived) * (1.0 / Double(config.frameDurationSeconds)) - / Double(config.sampleRate))) - ) - let paddingSamples = max( - config.melWindow, - config.chunkRightContext * config.subsamplingFactor * config.melStride - ) - - var stalledPasses = 0 - while _numFramesProcessed < targetEndFrame && stalledPasses < 4 { - audioBuffer.append(contentsOf: [Float](repeating: 0, count: paddingSamples)) - preprocessAudioToFeaturesLocked() + let targetEndFrame = _numFramesProcessed + min(_timeline.numTentativeFrames, config.chunkLen) + let exactPaddingSamples = exactFinalizationPaddingSamples(targetEndFrame: targetEndFrame) + if exactPaddingSamples > 0 { + audioBuffer.append(contentsOf: [Float](repeating: 0, count: exactPaddingSamples)) + } + while _numFramesProcessed < targetEndFrame { + let remainingFrames = targetEndFrame - _numFramesProcessed + let targetFeatureFrames = + startFeat + + min(remainingFrames * config.subsamplingFactor, config.coreFrames) + + config.chunkRightContext * config.subsamplingFactor + preprocessAudioToFeatureTargetLocked(targetFeatureFrames: targetFeatureFrames) guard let chunk = try makeStreamingChunkLocked(), chunk.finalizedFrameCount > 0 else { - stalledPasses += 1 - continue + logger.warning( + "Sortformer finalize could not emit enough confirmed frames " + + "(\(_numFramesProcessed) / \(targetEndFrame) frames)" + ) + break } - let remainingFrames = targetEndFrame - _numFramesProcessed let finalizedFrameCount = min(remainingFrames, chunk.finalizedFrameCount) guard finalizedFrameCount > 0 else { - stalledPasses += 1 - continue + break } let finalizedPredictions = Array( @@ -556,14 +553,6 @@ public final class SortformerDiarizer: Diarizer { _numFramesProcessed += finalizedFrameCount try _timeline.addChunk(finalizedResult) lastResult = finalizedResult - stalledPasses = 0 - } - - if _numFramesProcessed < targetEndFrame { - logger.warning( - "Sortformer finalize stalled before last true frame was confirmed " - + "(\(_numFramesProcessed) / \(targetEndFrame) frames)" - ) } _timeline.finalize() @@ -758,38 +747,17 @@ public final class SortformerDiarizer: Diarizer { /// Preprocess audio into mel features - caller must hold lock private func preprocessAudioToFeaturesLocked() { + let targetFeatureFrames = startFeat + config.coreFrames + config.chunkRightContext * config.subsamplingFactor + preprocessAudioToFeatureTargetLocked(targetFeatureFrames: targetFeatureFrames) + } + + private func preprocessAudioToFeatureTargetLocked(targetFeatureFrames: Int) { guard !audioBuffer.isEmpty else { return } if audioBuffer.count < config.melWindow { return } - // Demand-Driven Optimization: - // Calculate exactly how many features we need for the next chunk - // needed = (startFeat + core + RC) - currentFeatureCount - let featLength = featureBuffer.count / config.melFeatures - let coreFrames = config.chunkLen * config.subsamplingFactor - let rightContextFrames = config.chunkRightContext * config.subsamplingFactor - - // Calculate absolute target position in feature stream - // For Chunk 0: startFeat=0. Target=104. - // For Chunk 1: startFeat=8. Target=112. - let targetEnd = startFeat + coreFrames + rightContextFrames - - let framesNeeded = targetEnd - featLength - - // If we already have enough frames, we don't strictly need to process more right now. - // However, to keep the pipeline moving smoothly, we can process if we have a full chunk buffered. - // But to strictly prioritize efficiency/latency balance as requested: - if framesNeeded <= 0 { - // We have enough features for the next chunk! - // Check if we have A LOT of audio buffered (buffer pressure)? - // If we have > 1 second of audio, maybe process it batch-wise? - // For now, lazy approach: don't process. - return - } - - // Calculate audio samples needed to produce 'framesNeeded' - // If we are appending to existing stream (featureBuffer not empty), we need stride * N. - // If featureBuffer is empty (start of stream), we need window + (N-1)*stride. + let framesNeeded = targetFeatureFrames - featLength + guard framesNeeded > 0 else { return } let samplesNeeded: Int if featureBuffer.isEmpty { @@ -798,12 +766,7 @@ public final class SortformerDiarizer: Diarizer { samplesNeeded = framesNeeded * config.melStride } - // Wait until we have enough audio to satisfy the demand - if audioBuffer.count < samplesNeeded { return } - - // We have enough audio! Process exactly what's needed (or slightly more if convenient?) - // Let's process everything we have, since we paid the initialization cost check. - // This prevents creating a backlog of unprocessed audio. + guard audioBuffer.count >= samplesNeeded else { return } let (mel, melLength, _) = melSpectrogram.computeFlatTransposed( audio: audioBuffer, @@ -839,6 +802,54 @@ public final class SortformerDiarizer: Diarizer { .resample(samples, from: sourceSampleRate) } + private func exactFinalizationPaddingSamples(targetEndFrame: Int) -> Int { + let targetFeatureFrames = max( + 0, + startFeat + max(0, targetEndFrame - _numFramesProcessed) * config.subsamplingFactor + + config.chunkRightContext * config.subsamplingFactor + ) + let currentFeatureFrames = featureBuffer.count / config.melFeatures + let additionalFeatureFramesNeeded = max(0, targetFeatureFrames - currentFeatureFrames) + guard additionalFeatureFramesNeeded > 0 else { + return 0 + } + + if producedMelFrames(forFinalizationPadding: 0) >= additionalFeatureFramesNeeded { + return 0 + } + + var upperBound = max(config.melStride, config.melWindow) + while producedMelFrames(forFinalizationPadding: upperBound) < additionalFeatureFramesNeeded { + upperBound *= 2 + } + + var lowerBound = 0 + while lowerBound < upperBound { + let middle = lowerBound + (upperBound - lowerBound) / 2 + if producedMelFrames(forFinalizationPadding: middle) >= additionalFeatureFramesNeeded { + upperBound = middle + } else { + lowerBound = middle + 1 + } + } + + return lowerBound + } + + private func producedMelFrames(forFinalizationPadding paddingSamples: Int) -> Int { + let paddedAudio: [Float] + if paddingSamples > 0 { + paddedAudio = audioBuffer + [Float](repeating: 0, count: paddingSamples) + } else { + paddedAudio = audioBuffer + } + let (_, melLength, _) = melSpectrogram.computeFlatTransposed( + audio: paddedAudio, + lastAudioSample: lastAudioSample + ) + return melLength + } + /// Get next chunk features (for testing) internal func getNextChunkFeatures() -> (mel: [Float], melLength: Int)? { lock.lock() From 898b5014fe20c5d61cdbfd94bd6b46050688878f Mon Sep 17 00:00:00 2001 From: Benjamin Lee Date: Tue, 24 Mar 2026 20:22:00 -0700 Subject: [PATCH 04/12] Preserve Sortformer tentative tail during finalize --- .../SortformerDiarizerPipeline.swift | 4 ++- .../SortformerStreamingIntegrationTests.swift | 31 +++++++++++++++++++ 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/Sources/FluidAudio/Diarizer/Sortformer/SortformerDiarizerPipeline.swift b/Sources/FluidAudio/Diarizer/Sortformer/SortformerDiarizerPipeline.swift index cdd343203..fcbbdd31f 100644 --- a/Sources/FluidAudio/Diarizer/Sortformer/SortformerDiarizerPipeline.swift +++ b/Sources/FluidAudio/Diarizer/Sortformer/SortformerDiarizerPipeline.swift @@ -501,7 +501,9 @@ public final class SortformerDiarizer: Diarizer { } var lastResult: DiarizerChunkResult? + var tentativeToFlush = _timeline.numTentativeFrames if let chunk = try makeStreamingChunkLocked() { + tentativeToFlush = chunk.tentativeFrameCount let finalizedChunk = DiarizerChunkResult( startFrame: chunk.startFrame, finalizedPredictions: chunk.finalizedPredictions, @@ -514,7 +516,7 @@ public final class SortformerDiarizer: Diarizer { lastResult = finalizedChunk } - let targetEndFrame = _numFramesProcessed + min(_timeline.numTentativeFrames, config.chunkLen) + let targetEndFrame = _numFramesProcessed + min(tentativeToFlush, config.chunkLen) let exactPaddingSamples = exactFinalizationPaddingSamples(targetEndFrame: targetEndFrame) if exactPaddingSamples > 0 { audioBuffer.append(contentsOf: [Float](repeating: 0, count: exactPaddingSamples)) diff --git a/Tests/FluidAudioTests/Diarizer/Sortformer/SortformerStreamingIntegrationTests.swift b/Tests/FluidAudioTests/Diarizer/Sortformer/SortformerStreamingIntegrationTests.swift index 4e2a87a22..76a9ca54b 100644 --- a/Tests/FluidAudioTests/Diarizer/Sortformer/SortformerStreamingIntegrationTests.swift +++ b/Tests/FluidAudioTests/Diarizer/Sortformer/SortformerStreamingIntegrationTests.swift @@ -44,4 +44,35 @@ final class SortformerStreamingIntegrationTests: XCTestCase { XCTAssertEqual(finalChunk?.tentativePredictions.count, 0) XCTAssertEqual(streamingDiarizer.timeline.numTentativeFrames, 0) } + + func testFinalizeSessionFlushesTentativeTailAfterAddAudioOnly() async throws { + let config = SortformerConfig.default + let models: SortformerModels + do { + models = try await loadModelsForTest(config: config) + } catch { + throw XCTSkip("Sortformer models unavailable in this environment: \(error)") + } + let samples = try DiarizationTestFixtures.fixtureAudio(sampleRate: config.sampleRate, limitSeconds: 4.0) + + let bufferedDiarizer = SortformerDiarizer(config: config) + bufferedDiarizer.initialize(models: models) + bufferedDiarizer.addAudio(samples) + let bufferedFinalChunk = try bufferedDiarizer.finalizeSession() + + let streamingDiarizer = SortformerDiarizer(config: config) + streamingDiarizer.initialize(models: models) + for chunk in DiarizationTestFixtures.chunk(samples, sizes: [4_800, 7_680, 9_600]) { + let _ = try streamingDiarizer.process(samples: chunk) + } + let _ = try streamingDiarizer.finalizeSession() + + XCTAssertNotNil(bufferedFinalChunk) + XCTAssertEqual(bufferedFinalChunk?.tentativeFrameCount, 0) + XCTAssertEqual(bufferedDiarizer.timeline.numTentativeFrames, 0) + XCTAssertEqual( + bufferedDiarizer.timeline.numFinalizedFrames, + streamingDiarizer.timeline.numFinalizedFrames + ) + } } From 63e81f83d3bb428b257996805b091e50307af37c Mon Sep 17 00:00:00 2001 From: Benjamin Lee Date: Wed, 25 Mar 2026 10:28:22 -0700 Subject: [PATCH 05/12] Format diarizer timeline updates --- .../FluidAudio/Diarizer/DiarizerTimeline.swift | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/Sources/FluidAudio/Diarizer/DiarizerTimeline.swift b/Sources/FluidAudio/Diarizer/DiarizerTimeline.swift index a75553c82..b8a51bddd 100644 --- a/Sources/FluidAudio/Diarizer/DiarizerTimeline.swift +++ b/Sources/FluidAudio/Diarizer/DiarizerTimeline.swift @@ -688,6 +688,11 @@ public final class DiarizerTimeline { queue.sync { _tentativePredictions.count / speakerCapacity } } + /// Total number of frames (finalized + tentative) + public var numFrames: Int { + queue.sync { _numFinalizedFrames + _tentativePredictions.count / speakerCapacity } + } + /// Speakers in the timeline public var speakers: [Int: DiarizerSpeaker] { get { queue.sync { _speakers } } @@ -706,6 +711,7 @@ public final class DiarizerTimeline { } } + /// Whether the timeline has any segments public var hasSegments: Bool { speakers.values.contains(where: \.hasSegments) } @@ -715,6 +721,16 @@ public final class DiarizerTimeline { Float(numFinalizedFrames) * config.frameDurationSeconds } + /// Duration of tentative predictions in seconds + public var tentativeDuration: Float { + Float(numTentativeFrames) * config.frameDurationSeconds + } + + /// Duration of all predictions (finalized + tentative) in seconds + public var duration: Float { + Float(numFrames) * config.frameDurationSeconds + } + /// Maximum number of speakers public var speakerCapacity: Int { config.numSpeakers From 1d614c249aecc12ab4de7b85b54ac82cb601ac91 Mon Sep 17 00:00:00 2001 From: Benjamin Lee Date: Wed, 25 Mar 2026 10:55:40 -0700 Subject: [PATCH 06/12] Format diarizer API comments --- Sources/FluidAudio/Diarizer/LS-EEND/LSEENDDiarizerAPI.swift | 5 +---- .../Diarizer/Sortformer/SortformerDiarizerPipeline.swift | 5 +---- 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/Sources/FluidAudio/Diarizer/LS-EEND/LSEENDDiarizerAPI.swift b/Sources/FluidAudio/Diarizer/LS-EEND/LSEENDDiarizerAPI.swift index a87470514..09e108174 100644 --- a/Sources/FluidAudio/Diarizer/LS-EEND/LSEENDDiarizerAPI.swift +++ b/Sources/FluidAudio/Diarizer/LS-EEND/LSEENDDiarizerAPI.swift @@ -406,10 +406,7 @@ public final class LSEENDDiarizer: Diarizer { return try processLocked() } - /// Process a chunk of audio in one call. - /// - /// Convenience method that combines `addAudio()` and `process()`. - /// + /// Add and process a chunk of audio in one call. /// - Parameters: /// - samples: Audio samples to process. /// - sourceSampleRate: Sample rate of `samples`, or `nil` if already at the model rate. diff --git a/Sources/FluidAudio/Diarizer/Sortformer/SortformerDiarizerPipeline.swift b/Sources/FluidAudio/Diarizer/Sortformer/SortformerDiarizerPipeline.swift index fcbbdd31f..0db84fa19 100644 --- a/Sources/FluidAudio/Diarizer/Sortformer/SortformerDiarizerPipeline.swift +++ b/Sources/FluidAudio/Diarizer/Sortformer/SortformerDiarizerPipeline.swift @@ -387,10 +387,7 @@ public final class SortformerDiarizer: Diarizer { } } - /// Process a chunk of audio in one call. - /// - /// Convenience method that combines `addAudio()` and `process()`. - /// + /// Add and process a chunk of audio in one call. /// - Parameters: /// - samples: Audio samples (16kHz mono) /// - sourceSampleRate: Source audio sample rate From 3b48c0b4adff76c5bcc6f97d4c8307f461ce23b1 Mon Sep 17 00:00:00 2001 From: Benjamin Lee <48599511+SGD2718@users.noreply.github.com> Date: Wed, 25 Mar 2026 12:09:32 -0700 Subject: [PATCH 07/12] Clarify comment on flushing frames at stream end Updated comment to clarify the flushing of frames. --- Documentation/Diarization/LS-EEND.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Documentation/Diarization/LS-EEND.md b/Documentation/Diarization/LS-EEND.md index b8f740e46..5d56435b3 100644 --- a/Documentation/Diarization/LS-EEND.md +++ b/Documentation/Diarization/LS-EEND.md @@ -226,7 +226,7 @@ if let update = try diarizer.process() { // Convenience: add + process in one call if let update = try diarizer.process(samples: audioChunk) { ... } -// Flush remaining committed + preview frames at end of stream +// Flush remaining frames at the end of a stream try diarizer.finalizeSession() let finalTimeline = diarizer.timeline ``` From d7ec9463f67e8d87337d7ff37ed5207a8cacf759 Mon Sep 17 00:00:00 2001 From: Benjamin Lee <48599511+SGD2718@users.noreply.github.com> Date: Wed, 25 Mar 2026 12:10:38 -0700 Subject: [PATCH 08/12] Clarify finalizeSession() functionality in documentation Updated the description of `finalizeSession()` to clarify its functionality. --- Documentation/Diarization/LS-EEND.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Documentation/Diarization/LS-EEND.md b/Documentation/Diarization/LS-EEND.md index 5d56435b3..50d74c8ab 100644 --- a/Documentation/Diarization/LS-EEND.md +++ b/Documentation/Diarization/LS-EEND.md @@ -232,7 +232,7 @@ let finalTimeline = diarizer.timeline ``` Notes: -- `finalizeSession()` pads the tail with silence when needed so the last true frame is emitted as finalized output before the timeline is finalized. +- `finalizeSession()` flushes the remaining audio by padding the end with silence. ### Speaker Enrollment From 0c2bed0a0afe60a6e21c6fccd8ee4e8e58547fa4 Mon Sep 17 00:00:00 2001 From: Benjamin Lee <48599511+SGD2718@users.noreply.github.com> Date: Wed, 25 Mar 2026 12:15:46 -0700 Subject: [PATCH 09/12] Refactor locking mechanism for property access --- .../Diarizer/LS-EEND/LSEENDDiarizerAPI.swift | 334 ++++++++---------- 1 file changed, 152 insertions(+), 182 deletions(-) diff --git a/Sources/FluidAudio/Diarizer/LS-EEND/LSEENDDiarizerAPI.swift b/Sources/FluidAudio/Diarizer/LS-EEND/LSEENDDiarizerAPI.swift index 09e108174..323b606ed 100644 --- a/Sources/FluidAudio/Diarizer/LS-EEND/LSEENDDiarizerAPI.swift +++ b/Sources/FluidAudio/Diarizer/LS-EEND/LSEENDDiarizerAPI.swift @@ -14,44 +14,32 @@ public final class LSEENDDiarizer: Diarizer { /// Accumulated results public var timeline: DiarizerTimeline { - lock.lock() - defer { lock.unlock() } - return _timeline + lock.withLock { return _timeline } } /// Whether the processor is ready for processing public var isAvailable: Bool { - lock.lock() - defer { lock.unlock() } - return _engine != nil + lock.withLock { return _engine != nil } } /// Number of confirmed frames processed so far public var numFramesProcessed: Int { - lock.lock() - defer { lock.unlock() } - return _numFramesProcessed + lock.withLock { return _numFramesProcessed } } /// Model's target sample rate in Hz (e.g., 8000) public var targetSampleRate: Int? { - lock.lock() - defer { lock.unlock() } - return _engine?.targetSampleRate + lock.withLock { return _engine?.targetSampleRate } } /// Output frame rate in Hz (e.g., 10.0) public var modelFrameHz: Double? { - lock.lock() - defer { lock.unlock() } - return _engine?.modelFrameHz + lock.withLock { return _engine?.modelFrameHz } } /// Number of real speaker tracks (excluding boundary tracks) public var numSpeakers: Int? { - lock.lock() - defer { lock.unlock() } - return _engine?.metadata.realOutputDim + lock.withLock { return _engine?.metadata.realOutputDim } } // MARK: - Additional Properties @@ -61,30 +49,22 @@ public final class LSEENDDiarizer: Diarizer { /// Post-processing configuration public var timelineConfig: DiarizerTimelineConfig { - lock.lock() - defer { lock.unlock() } - return _timeline.config + lock.withLock { return _timeline.config } } /// Streaming latency in seconds public var streamingLatencySeconds: Double? { - lock.lock() - defer { lock.unlock() } - return _engine?.streamingLatencySeconds + lock.withLock { return _engine?.streamingLatencySeconds } } /// Total speaker slots in model output (including boundary tracks) public var decodeMaxSpeakers: Int? { - lock.lock() - defer { lock.unlock() } - return _engine?.decodeMaxSpeakers + lock.withLock { return _engine?.decodeMaxSpeakers } } /// Whether a streaming session is currently active. var hasActiveSession: Bool { - lock.lock() - defer { lock.unlock() } - return _session != nil + lock.withLock { return _session != nil } } // MARK: - Private State @@ -176,39 +156,37 @@ public final class LSEENDDiarizer: Diarizer { let engine = try LSEENDInferenceHelper(descriptor: descriptor, computeUnits: computeUnits) let melSpectrogram = Self.createMelSpectrogram(featureConfig: engine.featureConfig) - lock.lock() - defer { lock.unlock() } - - updateTimelineConfig(engine: engine) - _engine = engine - _melSpectrogram = melSpectrogram - _timeline = DiarizerTimeline(config: _timelineConfig) - _session = nil - resetBuffersLocked() - - logger.info( - "Initialized LS-EEND \(descriptor.variant.rawValue): " - + "\(engine.metadata.realOutputDim) speakers, " - + "\(String(format: "%.1f", engine.modelFrameHz)) Hz, " - + "\(String(format: "%.2f", engine.streamingLatencySeconds))s latency" - ) + lock.withLock { + updateTimelineConfig(engine: engine) + _engine = engine + _melSpectrogram = melSpectrogram + _timeline = DiarizerTimeline(config: _timelineConfig) + _session = nil + resetBuffersLocked() + + logger.info( + "Initialized LS-EEND \(descriptor.variant.rawValue): " + + "\(engine.metadata.realOutputDim) speakers, " + + "\(String(format: "%.1f", engine.modelFrameHz)) Hz, " + + "\(String(format: "%.2f", engine.streamingLatencySeconds))s latency" + ) + } } /// Initialize with a pre-loaded engine. public func initialize(engine: LSEENDInferenceHelper) { let melSpectrogram = Self.createMelSpectrogram(featureConfig: engine.featureConfig) - lock.lock() - defer { lock.unlock() } - - updateTimelineConfig(engine: engine) - _engine = engine - _melSpectrogram = melSpectrogram - _timeline = DiarizerTimeline(config: _timelineConfig) - _session = nil - resetBuffersLocked() + lock.withLock { + updateTimelineConfig(engine: engine) + _engine = engine + _melSpectrogram = melSpectrogram + _timeline = DiarizerTimeline(config: _timelineConfig) + _session = nil + resetBuffersLocked() - logger.info("Initialized LS-EEND with pre-loaded engine") + logger.info("Initialized LS-EEND with pre-loaded engine") + } } // MARK: - Speaker Priming @@ -268,103 +246,101 @@ public final class LSEENDDiarizer: Diarizer { named name: String?, overwritingAssignedSpeakerName overwriteAssignedSpeakerName: Bool ) throws -> DiarizerSpeaker? { - let description: String = name.map { "named '\($0)'" } ?? "(no name)" - - lock.lock() - defer { lock.unlock() } - - guard let engine = _engine else { - throw LSEENDError.modelPredictionFailed("LS-EEND processor not initialized. Call initialize() first.") - } - - let normalized = try normalizeSamplesLocked(samples, sourceSampleRate: sourceSampleRate) ?? samples - guard !normalized.isEmpty else { - logger.warning("Failed to enroll speaker \(description) because no speech detected") - return nil - } - - if _timeline.hasSegments { - logger.warning("Trying to enroll a speaker while timeline has segments; timeline will be reset") - } - - _timeline.reset(keepingSpeakers: true) - var occupiedIndices = Set(_timeline.speakers.keys) - _numFramesProcessed = 0 - _visibleStartFrameOffset = 0 - pendingAudio.removeAll(keepingCapacity: true) - - if _session == nil { - _session = try engine.createSession( - inputSampleRate: engine.targetSampleRate, melSpectrogram: _melSpectrogram!) - } - guard let session = _session else { - return nil - } - - let update = try session.pushAudio(normalized) - let didProcess = update.map { !$0.probabilities.isEmpty || !$0.previewProbabilities.isEmpty } ?? false - - guard didProcess else { - let minimumSeconds = engine.streamingLatencySeconds - logger.warning( - "Failed to enroll speaker \(description): not enough audio was provided. " - + "Please provide at least \(String(format: "%.2f", minimumSeconds)) seconds of speech." - ) - return nil - } - - if let update { - let numSpeakers = engine.metadata.realOutputDim - let result = DiarizerChunkResult( - startFrame: max(0, update.startFrame - _visibleStartFrameOffset), - finalizedPredictions: flattenRowMajor(update.probabilities, numSpeakers: numSpeakers), - finalizedFrameCount: update.probabilities.rows, - tentativePredictions: flattenRowMajor(update.previewProbabilities, numSpeakers: numSpeakers), - tentativeFrameCount: update.previewProbabilities.rows - ) - _numFramesProcessed += result.finalizedFrameCount - _ = try _timeline.addChunk(result) - } - - let speaker = _timeline.speakers.values.max { $0.numSpeechFrames < $1.numSpeechFrames } - let enrolledSpeaker: DiarizerSpeaker? - if let speaker, speaker.hasSegments { - if let oldName = speaker.name { - guard overwriteAssignedSpeakerName else { + try lock.withLock { + let description: String = name.map { "named '\($0)'" } ?? "(no name)" + guard let engine = _engine else { + throw LSEENDError.modelPredictionFailed("LS-EEND processor not initialized. Call initialize() first.") + } + + let normalized = try normalizeSamplesLocked(samples, sourceSampleRate: sourceSampleRate) ?? samples + guard !normalized.isEmpty else { + logger.warning("Failed to enroll speaker \(description) because no speech detected") + return nil + } + + if _timeline.hasSegments { + logger.warning("Trying to enroll a speaker while timeline has segments; timeline will be reset") + } + + _timeline.reset(keepingSpeakers: true) + var occupiedIndices = Set(_timeline.speakers.keys) + _numFramesProcessed = 0 + _visibleStartFrameOffset = 0 + pendingAudio.removeAll(keepingCapacity: true) + + if _session == nil { + _session = try engine.createSession( + inputSampleRate: engine.targetSampleRate, melSpectrogram: _melSpectrogram!) + } + guard let session = _session else { + return nil + } + + let update = try session.pushAudio(normalized) + let didProcess = update.map { !$0.probabilities.isEmpty || !$0.previewProbabilities.isEmpty } ?? false + + guard didProcess else { + let minimumSeconds = engine.streamingLatencySeconds + logger.warning( + "Failed to enroll speaker \(description): not enough audio was provided. " + + "Please provide at least \(String(format: "%.2f", minimumSeconds)) seconds of speech." + ) + return nil + } + + if let update { + let numSpeakers = engine.metadata.realOutputDim + let result = DiarizerChunkResult( + startFrame: max(0, update.startFrame - _visibleStartFrameOffset), + finalizedPredictions: flattenRowMajor(update.probabilities, numSpeakers: numSpeakers), + finalizedFrameCount: update.probabilities.rows, + tentativePredictions: flattenRowMajor(update.previewProbabilities, numSpeakers: numSpeakers), + tentativeFrameCount: update.previewProbabilities.rows + ) + _numFramesProcessed += result.finalizedFrameCount + _ = try _timeline.addChunk(result) + } + + let speaker = _timeline.speakers.values.max { $0.numSpeechFrames < $1.numSpeechFrames } + let enrolledSpeaker: DiarizerSpeaker? + if let speaker, speaker.hasSegments { + if let oldName = speaker.name { + guard overwriteAssignedSpeakerName else { + logger.warning( + "Failed to enroll speaker \(description): diarizer matched existing speaker '\(oldName)' " + + "at index \(speaker.index) and overwritingAssignedSpeakerName=false" + ) + _visibleStartFrameOffset = session.snapshot().probabilities.rows + _numFramesProcessed = 0 + _timeline.reset(keepingSpeakersWhere: { occupiedIndices.contains($0.index) }) + pendingAudio.removeAll(keepingCapacity: true) + return nil + } logger.warning( - "Failed to enroll speaker \(description): diarizer matched existing speaker '\(oldName)' " - + "at index \(speaker.index) and overwritingAssignedSpeakerName=false" + "Newly-enrolled speaker \(description) will overwrite the old one named \(oldName) at index \(speaker.index)" ) - _visibleStartFrameOffset = session.snapshot().probabilities.rows - _numFramesProcessed = 0 - _timeline.reset(keepingSpeakersWhere: { occupiedIndices.contains($0.index) }) - pendingAudio.removeAll(keepingCapacity: true) - return nil } - logger.warning( - "Newly-enrolled speaker \(description) will overwrite the old one named \(oldName) at index \(speaker.index)" - ) + speaker.name = name + occupiedIndices.insert(speaker.index) + enrolledSpeaker = speaker + } else { + logger.warning("Failed to enroll speaker \(description) because no speech detected") + enrolledSpeaker = nil } - speaker.name = name - occupiedIndices.insert(speaker.index) - enrolledSpeaker = speaker - } else { - logger.warning("Failed to enroll speaker \(description) because no speech detected") - enrolledSpeaker = nil + + _visibleStartFrameOffset = session.snapshot().probabilities.rows + _numFramesProcessed = 0 + _timeline.reset(keepingSpeakersWhere: { occupiedIndices.contains($0.index) }) + pendingAudio.removeAll(keepingCapacity: true) + + logger.info( + "Enrolled speaker \(description) with \(normalized.count) samples " + + "(\(String(format: "%.1f", Float(normalized.count) / Float(engine.targetSampleRate)))s), " + + "visible offset=\(_visibleStartFrameOffset)" + ) + + return enrolledSpeaker } - - _visibleStartFrameOffset = session.snapshot().probabilities.rows - _numFramesProcessed = 0 - _timeline.reset(keepingSpeakersWhere: { occupiedIndices.contains($0.index) }) - pendingAudio.removeAll(keepingCapacity: true) - - logger.info( - "Enrolled speaker \(description) with \(normalized.count) samples " - + "(\(String(format: "%.1f", Float(normalized.count) / Float(engine.targetSampleRate)))s), " - + "visible offset=\(_visibleStartFrameOffset)" - ) - - return enrolledSpeaker } // MARK: - Streaming (Diarizer Protocol) @@ -387,13 +363,12 @@ public final class LSEENDDiarizer: Diarizer { _ samples: C, sourceSampleRate: Double? = nil ) throws where C.Element == Float { - lock.lock() - defer { lock.unlock() } - - if let normalized = try normalizeSamplesLocked(samples, sourceSampleRate: sourceSampleRate) { - pendingAudio.append(contentsOf: normalized) - } else { - pendingAudio.append(contentsOf: samples) + try lock.withLock { + if let normalized = try normalizeSamplesLocked(samples, sourceSampleRate: sourceSampleRate) { + pendingAudio.append(contentsOf: normalized) + } else { + pendingAudio.append(contentsOf: samples) + } } } @@ -401,9 +376,7 @@ public final class LSEENDDiarizer: Diarizer { /// /// - Returns: New chunk result if inference produced frames, nil otherwise public func process() throws -> DiarizerTimelineUpdate? { - lock.lock() - defer { lock.unlock() } - return try processLocked() + try lock.withLock { return try processLocked() } } /// Add and process a chunk of audio in one call. @@ -415,16 +388,15 @@ public final class LSEENDDiarizer: Diarizer { samples: C, sourceSampleRate: Double? = nil ) throws -> DiarizerTimelineUpdate? where C.Element == Float { - lock.lock() - defer { lock.unlock() } + try lock.withLock { + if let normalized = try normalizeSamplesLocked(samples, sourceSampleRate: sourceSampleRate) { + pendingAudio.append(contentsOf: normalized) + } else { + pendingAudio.append(contentsOf: samples) + } - if let normalized = try normalizeSamplesLocked(samples, sourceSampleRate: sourceSampleRate) { - pendingAudio.append(contentsOf: normalized) - } else { - pendingAudio.append(contentsOf: samples) + return try processLocked() } - - return try processLocked() } /// Internal process — caller must hold lock. @@ -638,26 +610,24 @@ public final class LSEENDDiarizer: Diarizer { /// /// Preserves the loaded model. Call `initialize()` again to change models. public func reset() { - lock.lock() - defer { lock.unlock() } - - _session = nil - _timeline.reset() - resetBuffersLocked() - logger.debug("LS-EEND state reset") + lock.withLock { + _session = nil + _timeline.reset() + resetBuffersLocked() + logger.debug("LS-EEND state reset") + } } /// Clean up all resources including the loaded model. public func cleanup() { - lock.lock() - defer { lock.unlock() } - - _engine = nil - _session = nil - _melSpectrogram = nil - _timeline.reset() - resetBuffersLocked() - logger.info("LS-EEND resources cleaned up") + lock.withLock { + _engine = nil + _session = nil + _melSpectrogram = nil + _timeline.reset() + resetBuffersLocked() + logger.info("LS-EEND resources cleaned up") + } } // MARK: - LS-EEND Specific From c0b260aae701074166eeb5e5b48f0f82c0d33cc6 Mon Sep 17 00:00:00 2001 From: Benjamin Lee Date: Wed, 25 Mar 2026 12:26:01 -0700 Subject: [PATCH 10/12] Optimize finalization --- .../Diarizer/LS-EEND/LSEENDDiarizerAPI.swift | 2 +- .../LS-EEND/LSEENDModelInference.swift | 11 ++- .../SortformerDiarizerPipeline.swift | 97 +++++++++++++------ .../LS-EEND/LSEENDIntegrationTests.swift | 94 +++++++++++++++++- 4 files changed, 168 insertions(+), 36 deletions(-) diff --git a/Sources/FluidAudio/Diarizer/LS-EEND/LSEENDDiarizerAPI.swift b/Sources/FluidAudio/Diarizer/LS-EEND/LSEENDDiarizerAPI.swift index 323b606ed..bb5ae7493 100644 --- a/Sources/FluidAudio/Diarizer/LS-EEND/LSEENDDiarizerAPI.swift +++ b/Sources/FluidAudio/Diarizer/LS-EEND/LSEENDDiarizerAPI.swift @@ -327,7 +327,7 @@ public final class LSEENDDiarizer: Diarizer { logger.warning("Failed to enroll speaker \(description) because no speech detected") enrolledSpeaker = nil } - + _visibleStartFrameOffset = session.snapshot().probabilities.rows _numFramesProcessed = 0 _timeline.reset(keepingSpeakersWhere: { occupiedIndices.contains($0.index) }) diff --git a/Sources/FluidAudio/Diarizer/LS-EEND/LSEENDModelInference.swift b/Sources/FluidAudio/Diarizer/LS-EEND/LSEENDModelInference.swift index 495fbb4e6..f978b9006 100644 --- a/Sources/FluidAudio/Diarizer/LS-EEND/LSEENDModelInference.swift +++ b/Sources/FluidAudio/Diarizer/LS-EEND/LSEENDModelInference.swift @@ -516,7 +516,7 @@ public final class LSEENDStreamingSession { var committedFullLogits = LSEENDMatrix.empty(columns: engine.decodeMaxSpeakers) let targetEndFrame = Int( round(Double(totalInputSamples) / Double(max(inputSampleRate, 1)) * engine.modelFrameHz)) - let exactPaddingSamples = exactFinalizationPaddingSamples(targetEndFrame: targetEndFrame) + let exactPaddingSamples = try exactFinalizationPaddingSamples(targetEndFrame: targetEndFrame) if exactPaddingSamples > 0 { let features = try featureExtractor.pushAudio([Float](repeating: 0, count: exactPaddingSamples)) let committed = try ingestFeatures(features) @@ -539,12 +539,17 @@ public final class LSEENDStreamingSession { return try buildUpdate(committedFullLogits: committedFullLogits.appendingRows(tail), includePreview: false) } - private func exactFinalizationPaddingSamples(targetEndFrame: Int) -> Int { + private func exactFinalizationPaddingSamples(targetEndFrame: Int) throws -> Int { guard targetEndFrame > 0 else { return 0 } let stableBlockSize = engine.metadata.resolvedHopLength * engine.metadata.resolvedSubsampling - let requiredTotalSamples = targetEndFrame * stableBlockSize + let (requiredTotalSamples, overflow) = targetEndFrame.multipliedReportingOverflow(by: stableBlockSize) + guard !overflow else { + throw LSEENDError.unsupportedAudio( + "Finalization padding overflowed for \(targetEndFrame) frames at block size \(stableBlockSize)." + ) + } return max(0, requiredTotalSamples - totalInputSamples) } diff --git a/Sources/FluidAudio/Diarizer/Sortformer/SortformerDiarizerPipeline.swift b/Sources/FluidAudio/Diarizer/Sortformer/SortformerDiarizerPipeline.swift index 0db84fa19..6187e33db 100644 --- a/Sources/FluidAudio/Diarizer/Sortformer/SortformerDiarizerPipeline.swift +++ b/Sources/FluidAudio/Diarizer/Sortformer/SortformerDiarizerPipeline.swift @@ -514,7 +514,7 @@ public final class SortformerDiarizer: Diarizer { } let targetEndFrame = _numFramesProcessed + min(tentativeToFlush, config.chunkLen) - let exactPaddingSamples = exactFinalizationPaddingSamples(targetEndFrame: targetEndFrame) + let exactPaddingSamples = try exactFinalizationPaddingSamples(targetEndFrame: targetEndFrame) if exactPaddingSamples > 0 { audioBuffer.append(contentsOf: [Float](repeating: 0, count: exactPaddingSamples)) } @@ -801,52 +801,87 @@ public final class SortformerDiarizer: Diarizer { .resample(samples, from: sourceSampleRate) } - private func exactFinalizationPaddingSamples(targetEndFrame: Int) -> Int { - let targetFeatureFrames = max( - 0, - startFeat + max(0, targetEndFrame - _numFramesProcessed) * config.subsamplingFactor - + config.chunkRightContext * config.subsamplingFactor + private func exactFinalizationPaddingSamples(targetEndFrame: Int) throws -> Int { + let remainingFrames = max(0, targetEndFrame - _numFramesProcessed) + let (remainingFeatureFrames, remainingOverflow) = remainingFrames.multipliedReportingOverflow( + by: config.subsamplingFactor ) + guard !remainingOverflow else { + throw SortformerError.invalidState( + "Finalization remaining-frame expansion overflowed for \(remainingFrames) frames." + ) + } + + let (rightContextFeatureFrames, rightContextOverflow) = config.chunkRightContext.multipliedReportingOverflow( + by: config.subsamplingFactor + ) + guard !rightContextOverflow else { + throw SortformerError.invalidState( + "Finalization right-context expansion overflowed for \(config.chunkRightContext) frames." + ) + } + + let (targetWithoutContext, startOverflow) = startFeat.addingReportingOverflow(remainingFeatureFrames) + guard !startOverflow else { + throw SortformerError.invalidState( + "Finalization target feature frame calculation overflowed at startFeat=\(startFeat)." + ) + } + + let (targetFeatureFrames, contextOverflow) = targetWithoutContext.addingReportingOverflow( + rightContextFeatureFrames) + guard !contextOverflow else { + throw SortformerError.invalidState( + "Finalization target feature frame calculation overflowed after adding right context." + ) + } + let currentFeatureFrames = featureBuffer.count / config.melFeatures let additionalFeatureFramesNeeded = max(0, targetFeatureFrames - currentFeatureFrames) guard additionalFeatureFramesNeeded > 0 else { return 0 } - if producedMelFrames(forFinalizationPadding: 0) >= additionalFeatureFramesNeeded { + let framesAvailableWithoutPadding = producedMelFramesAvailable() + guard additionalFeatureFramesNeeded > framesAvailableWithoutPadding else { return 0 } - var upperBound = max(config.melStride, config.melWindow) - while producedMelFrames(forFinalizationPadding: upperBound) < additionalFeatureFramesNeeded { - upperBound *= 2 - } - - var lowerBound = 0 - while lowerBound < upperBound { - let middle = lowerBound + (upperBound - lowerBound) / 2 - if producedMelFrames(forFinalizationPadding: middle) >= additionalFeatureFramesNeeded { - upperBound = middle - } else { - lowerBound = middle + 1 + let requiredBufferedSamples: Int + if featureBuffer.isEmpty { + let (additionalSamples, overflow) = max(0, additionalFeatureFramesNeeded - 1).multipliedReportingOverflow( + by: config.melStride + ) + guard !overflow else { + throw SortformerError.invalidState( + "Finalization sample requirement overflowed for \(additionalFeatureFramesNeeded) feature frames." + ) + } + let (samples, windowOverflow) = additionalSamples.addingReportingOverflow(config.melWindow) + guard !windowOverflow else { + throw SortformerError.invalidState( + "Finalization sample requirement overflowed after adding melWindow." + ) } + requiredBufferedSamples = samples + } else { + let (samples, overflow) = additionalFeatureFramesNeeded.multipliedReportingOverflow(by: config.melStride) + guard !overflow else { + throw SortformerError.invalidState( + "Finalization sample requirement overflowed for \(additionalFeatureFramesNeeded) buffered frames." + ) + } + requiredBufferedSamples = max(config.melWindow, samples) } - return lowerBound + return max(0, requiredBufferedSamples - audioBuffer.count) } - private func producedMelFrames(forFinalizationPadding paddingSamples: Int) -> Int { - let paddedAudio: [Float] - if paddingSamples > 0 { - paddedAudio = audioBuffer + [Float](repeating: 0, count: paddingSamples) - } else { - paddedAudio = audioBuffer + private func producedMelFramesAvailable() -> Int { + guard audioBuffer.count >= config.melWindow else { + return 0 } - let (_, melLength, _) = melSpectrogram.computeFlatTransposed( - audio: paddedAudio, - lastAudioSample: lastAudioSample - ) - return melLength + return audioBuffer.count / config.melStride } /// Get next chunk features (for testing) diff --git a/Tests/FluidAudioTests/Diarizer/LS-EEND/LSEENDIntegrationTests.swift b/Tests/FluidAudioTests/Diarizer/LS-EEND/LSEENDIntegrationTests.swift index b7aa15d24..e4c7f2d55 100644 --- a/Tests/FluidAudioTests/Diarizer/LS-EEND/LSEENDIntegrationTests.swift +++ b/Tests/FluidAudioTests/Diarizer/LS-EEND/LSEENDIntegrationTests.swift @@ -4,13 +4,14 @@ import XCTest @testable import FluidAudio +@MainActor final class LSEENDIntegrationTests: XCTestCase { private struct ErrorStats { let maxAbs: Double let meanAbs: Double } - nonisolated(unsafe) private static var cachedEngines: [LSEENDVariant: LSEENDInferenceHelper] = [:] + private static var cachedEngines: [LSEENDVariant: LSEENDInferenceHelper] = [:] func testVariantRegistryResolvesAllExportedArtifacts() async throws { let expectedColumns: [LSEENDVariant: Int] = [ @@ -111,6 +112,54 @@ final class LSEENDIntegrationTests: XCTestCase { assertMatrixClose(snapshot.fullLogits, offline.fullLogits, maxAbs: 1e-5, meanAbs: 1e-6) } + func testStreamingFinalizationUsesExactPaddingForTailFlush() async throws { + let engine = try await makeEngine(variant: .dihard3) + let sampleRate = engine.targetSampleRate + let stableBlockSize = engine.metadata.resolvedHopLength * engine.metadata.resolvedSubsampling + let rawSamples = try DiarizationTestFixtures.fixtureAudio(sampleRate: sampleRate, limitSeconds: 6.0) + let sampleCount = stableBlockSize * 2 + stableBlockSize / 2 + 37 + let samples = Array(rawSamples.prefix(sampleCount)) + let expected = try engine.infer(samples: samples, sampleRate: sampleRate) + + let targetEndFrame = Int(round(Double(samples.count) / Double(sampleRate) * engine.modelFrameHz)) + let expectedPaddingSamples = max(0, targetEndFrame * stableBlockSize - samples.count) + + XCTAssertGreaterThan(expectedPaddingSamples, 0) + XCTAssertEqual((samples.count + expectedPaddingSamples) % stableBlockSize, 0) + + let session = try engine.createSession(inputSampleRate: sampleRate) + _ = try session.pushAudio(samples) + let finalUpdate = try session.finalize() + + XCTAssertNotNil(finalUpdate) + XCTAssertEqual(finalUpdate?.previewLogits.rows, 0) + XCTAssertEqual(finalUpdate?.previewProbabilities.rows, 0) + XCTAssertNil(try session.finalize()) + + let snapshot = session.snapshot() + XCTAssertEqual(snapshot.probabilities.rows, expected.probabilities.rows) + XCTAssertEqual(snapshot.logits.rows, expected.logits.rows) + assertMatrixClose(snapshot.logits, expected.logits, maxAbs: 1e-5, meanAbs: 1e-6) + assertMatrixClose(snapshot.probabilities, expected.probabilities, maxAbs: 1e-5, meanAbs: 1e-6) + assertMatrixClose(snapshot.fullLogits, expected.fullLogits, maxAbs: 1e-5, meanAbs: 1e-6) + assertMatrixClose(snapshot.fullProbabilities, expected.fullProbabilities, maxAbs: 1e-5, meanAbs: 1e-6) + } + + func testEmptyAudioFinalizationProducesNoOutput() async throws { + let engine = try await makeEngine(variant: .dihard3) + let session = try engine.createSession(inputSampleRate: engine.targetSampleRate) + + XCTAssertNil(try session.finalize()) + XCTAssertNil(try session.finalize()) + + let snapshot = session.snapshot() + XCTAssertEqual(snapshot.logits.rows, 0) + XCTAssertEqual(snapshot.probabilities.rows, 0) + XCTAssertEqual(snapshot.fullLogits.rows, 0) + XCTAssertEqual(snapshot.fullProbabilities.rows, 0) + XCTAssertEqual(snapshot.durationSeconds, 0) + } + func testStreamingSimulationMatchesOfflineInferenceAndReportsMonotonicProgress() async throws { let engine = try await makeEngine(variant: .dihard3) let fixtureURL = try DiarizationTestFixtures.fixtureAudioFileURL() @@ -250,6 +299,49 @@ final class LSEENDIntegrationTests: XCTestCase { XCTAssertFalse(diarizer.hasActiveSession) } + func testDiarizerTimelineCountAndDurationPropertiesReflectFrames() throws { + let frameDurationSeconds: Float = 0.25 + let timeline = DiarizerTimeline(config: .default(numSpeakers: 3, frameDurationSeconds: frameDurationSeconds)) + + XCTAssertEqual(timeline.numFinalizedFrames, 0) + XCTAssertEqual(timeline.numTentativeFrames, 0) + XCTAssertEqual(timeline.numFrames, 0) + XCTAssertEqual(timeline.finalizedDuration, 0) + XCTAssertEqual(timeline.tentativeDuration, 0) + XCTAssertEqual(timeline.duration, 0) + + let chunk = DiarizerChunkResult( + startFrame: 5, + finalizedPredictions: [ + 0.10, 0.20, 0.30, + 0.40, 0.50, 0.60, + ], + finalizedFrameCount: 2, + tentativePredictions: [ + 0.70, 0.80, 0.90, + ], + tentativeFrameCount: 1 + ) + + try timeline.addChunk(chunk) + + XCTAssertEqual(timeline.numFinalizedFrames, 2) + XCTAssertEqual(timeline.numTentativeFrames, 1) + XCTAssertEqual(timeline.numFrames, 3) + XCTAssertEqual(timeline.finalizedDuration, 0.5, accuracy: 1e-6) + XCTAssertEqual(timeline.tentativeDuration, 0.25, accuracy: 1e-6) + XCTAssertEqual(timeline.duration, 0.75, accuracy: 1e-6) + + timeline.finalize() + + XCTAssertEqual(timeline.numFinalizedFrames, 3) + XCTAssertEqual(timeline.numTentativeFrames, 0) + XCTAssertEqual(timeline.numFrames, 3) + XCTAssertEqual(timeline.finalizedDuration, 0.75, accuracy: 1e-6) + XCTAssertEqual(timeline.tentativeDuration, 0, accuracy: 1e-6) + XCTAssertEqual(timeline.duration, 0.75, accuracy: 1e-6) + } + private func makeEngine(variant: LSEENDVariant) async throws -> LSEENDInferenceHelper { if let cached = Self.cachedEngines[variant] { return cached From 64c13f4c20ff89cc787f2e381567e2dd6ebf7939 Mon Sep 17 00:00:00 2001 From: Benjamin Lee Date: Wed, 25 Mar 2026 14:30:59 -0700 Subject: [PATCH 11/12] Add diarizer segment confidence --- .../Diarizer/DiarizerTimeline.swift | 68 ++++++++++++++++--- .../Sortformer/SortformerTimelineTests.swift | 60 ++++++++++++++++ 2 files changed, 117 insertions(+), 11 deletions(-) diff --git a/Sources/FluidAudio/Diarizer/DiarizerTimeline.swift b/Sources/FluidAudio/Diarizer/DiarizerTimeline.swift index b8a51bddd..ef11047f4 100644 --- a/Sources/FluidAudio/Diarizer/DiarizerTimeline.swift +++ b/Sources/FluidAudio/Diarizer/DiarizerTimeline.swift @@ -506,6 +506,9 @@ public struct DiarizerSegment: Sendable, Identifiable, Comparable, Equatable { /// Duration of one frame in seconds public let frameDurationSeconds: Float + /// Confidence in this speech segment (average speech probability from the diarizer) + public var confidence: Float = 0.0 + /// Start time in seconds public var startTime: Float { Float(startFrame) * frameDurationSeconds } @@ -523,7 +526,8 @@ public struct DiarizerSegment: Sendable, Identifiable, Comparable, Equatable { startFrame: Int, endFrame: Int, finalized: Bool = true, - frameDurationSeconds: Float + frameDurationSeconds: Float, + confidence: Float = 0 ) { self.id = UUID() self.speakerIndex = speakerIndex @@ -531,6 +535,7 @@ public struct DiarizerSegment: Sendable, Identifiable, Comparable, Equatable { self.endFrame = endFrame self.isFinalized = finalized self.frameDurationSeconds = frameDurationSeconds + self.confidence = confidence } public init( @@ -538,7 +543,8 @@ public struct DiarizerSegment: Sendable, Identifiable, Comparable, Equatable { startTime: Float, endTime: Float, finalized: Bool = true, - frameDurationSeconds: Float + frameDurationSeconds: Float, + confidence: Float = 0 ) { self.id = UUID() self.speakerIndex = speakerIndex @@ -546,6 +552,7 @@ public struct DiarizerSegment: Sendable, Identifiable, Comparable, Equatable { self.endFrame = Int(round(endTime / frameDurationSeconds)) self.isFinalized = finalized self.frameDurationSeconds = frameDurationSeconds + self.confidence = confidence } /// Check if this overlaps with another segment @@ -639,6 +646,13 @@ public struct DiarizerChunkResult: Sendable { /// Generalizes `SortformerTimeline` for any frame-based diarizer. Works with /// both Sortformer (fixed 4 speakers) and LS-EEND (variable speaker count). public final class DiarizerTimeline { + private struct ClosedSegmentStats { + var start: Int + var end: Int + var activitySum: Float + var activeFrameCount: Int + } + public enum KeptOnReset { case nothing case namedSpeakers @@ -650,15 +664,21 @@ public final class DiarizerTimeline { private struct StreamingState { var startFrame: Int var isSpeaking: Bool - var lastSegment: (start: Int, end: Int) + var activitySum: Float + var activeFrameCount: Int + var lastSegment: ClosedSegmentStats? init( startFrame: Int = 0, isSpeaking: Bool = false, - lastSegment: (start: Int, end: Int) = (-1, -1) + activitySum: Float = 0, + activeFrameCount: Int = 0, + lastSegment: ClosedSegmentStats? = nil ) { self.startFrame = startFrame self.isSpeaking = isSpeaking + self.activitySum = activitySum + self.activeFrameCount = activeFrameCount self.lastSegment = lastSegment } } @@ -1119,42 +1139,64 @@ public final class DiarizerTimeline { var start = state.startFrame var speaking = state.isSpeaking + var activitySum = state.activitySum + var activeFrameCount = state.activeFrameCount var lastSegment = state.lastSegment var wasLastSegmentFinal = isFinalized for i in 0..= offset { + if activity >= offset { + activitySum += activity + activeFrameCount += 1 continue } speaking = false let end = frameOffset + i + padOffset - guard end - start > minFramesOn else { continue } + guard end - start > minFramesOn else { + activitySum = 0 + activeFrameCount = 0 + continue + } wasLastSegmentFinal = isFinalized && (end < tentativeStartFrame) + let confidence = activeFrameCount > 0 ? (activitySum / Float(activeFrameCount)) : 0 let newSegment = DiarizerSegment( speakerIndex: speakerIndex, startFrame: start, endFrame: end, finalized: wasLastSegmentFinal, - frameDurationSeconds: frameDuration + frameDurationSeconds: frameDuration, + confidence: confidence ) provideSpeaker(forSlot: speakerIndex).append(newSegment) - lastSegment = (start, end) + lastSegment = ClosedSegmentStats( + start: start, + end: end, + activitySum: activitySum, + activeFrameCount: activeFrameCount + ) + activitySum = 0 + activeFrameCount = 0 - } else if predictions[index] > onset { + } else if activity > onset { start = max(0, frameOffset + i - padOnset) speaking = true + activitySum = activity + activeFrameCount = 1 - if start - lastSegment.end <= minFramesOff { + if let lastSegment, start - lastSegment.end <= minFramesOff { start = lastSegment.start + activitySum += lastSegment.activitySum + activeFrameCount += lastSegment.activeFrameCount _speakers[speakerIndex]?.popLast(fromFinalized: wasLastSegmentFinal) } } @@ -1163,18 +1205,22 @@ public final class DiarizerTimeline { if isFinalized { states[speakerIndex].startFrame = start states[speakerIndex].isSpeaking = speaking + states[speakerIndex].activitySum = activitySum + states[speakerIndex].activeFrameCount = activeFrameCount states[speakerIndex].lastSegment = lastSegment } if addTrailingTentative { let end = frameOffset + numFrames + padOffset if speaking && (end > start) { + let confidence = activeFrameCount > 0 ? (activitySum / Float(activeFrameCount)) : 0 let newSegment = DiarizerSegment( speakerIndex: speakerIndex, startFrame: start, endFrame: end, finalized: false, - frameDurationSeconds: frameDuration + frameDurationSeconds: frameDuration, + confidence: confidence ) provideSpeaker(forSlot: speakerIndex).appendTentative(newSegment) } diff --git a/Tests/FluidAudioTests/Diarizer/Sortformer/SortformerTimelineTests.swift b/Tests/FluidAudioTests/Diarizer/Sortformer/SortformerTimelineTests.swift index 62da8b5a4..bb07c5b7c 100644 --- a/Tests/FluidAudioTests/Diarizer/Sortformer/SortformerTimelineTests.swift +++ b/Tests/FluidAudioTests/Diarizer/Sortformer/SortformerTimelineTests.swift @@ -136,6 +136,66 @@ final class SortformerTimelineTests: XCTestCase { XCTAssertTrue(timeline.tentativePredictions.isEmpty) } + func testSegmentConfidenceExcludesPaddingFrames() throws { + let config = DiarizerTimelineConfig( + numSpeakers: 1, + frameDurationSeconds: 0.08, + onsetThreshold: 0.5, + offsetThreshold: 0.5, + onsetPadFrames: 1, + offsetPadFrames: 2, + minFramesOn: 0, + minFramesOff: 0 + ) + let timeline = DiarizerTimeline(config: config) + let predictions: [Float] = [0.0, 0.8, 0.6, 0.0] + + try timeline.addChunk( + DiarizerChunkResult( + startFrame: 0, + finalizedPredictions: predictions, + finalizedFrameCount: predictions.count + ) + ) + + timeline.finalize() + + let segment = try XCTUnwrap(timeline.speakers[0]?.finalizedSegments.first) + XCTAssertEqual(segment.startFrame, 0) + XCTAssertEqual(segment.endFrame, 5) + XCTAssertEqual(segment.confidence, 0.7, accuracy: 1e-6) + } + + func testSegmentConfidenceExcludesBridgedGapFrames() throws { + let config = DiarizerTimelineConfig( + numSpeakers: 1, + frameDurationSeconds: 0.08, + onsetThreshold: 0.5, + offsetThreshold: 0.5, + onsetPadFrames: 0, + offsetPadFrames: 0, + minFramesOn: 0, + minFramesOff: 1 + ) + let timeline = DiarizerTimeline(config: config) + let predictions: [Float] = [0.9, 0.0, 0.7, 0.7, 0.0] + + try timeline.addChunk( + DiarizerChunkResult( + startFrame: 0, + finalizedPredictions: predictions, + finalizedFrameCount: predictions.count + ) + ) + + timeline.finalize() + + let segment = try XCTUnwrap(timeline.speakers[0]?.finalizedSegments.first) + XCTAssertEqual(segment.startFrame, 0) + XCTAssertEqual(segment.endFrame, 4) + XCTAssertEqual(segment.confidence, (0.9 + 0.7 + 0.7) / 3.0, accuracy: 1e-6) + } + // MARK: - Probability Access func testProbabilityAccess() throws { From 223519f3f513d821c13e0dd16f29e555cf1b0f61 Mon Sep 17 00:00:00 2001 From: Benjamin Lee Date: Wed, 25 Mar 2026 15:19:09 -0700 Subject: [PATCH 12/12] Format sources and tests --- .../Diarizer/LS-EEND/LSEENDDiarizerAPI.swift | 34 +++++++++---------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/Sources/FluidAudio/Diarizer/LS-EEND/LSEENDDiarizerAPI.swift b/Sources/FluidAudio/Diarizer/LS-EEND/LSEENDDiarizerAPI.swift index bb5ae7493..b74affa83 100644 --- a/Sources/FluidAudio/Diarizer/LS-EEND/LSEENDDiarizerAPI.swift +++ b/Sources/FluidAudio/Diarizer/LS-EEND/LSEENDDiarizerAPI.swift @@ -156,14 +156,14 @@ public final class LSEENDDiarizer: Diarizer { let engine = try LSEENDInferenceHelper(descriptor: descriptor, computeUnits: computeUnits) let melSpectrogram = Self.createMelSpectrogram(featureConfig: engine.featureConfig) - lock.withLock { + lock.withLock { updateTimelineConfig(engine: engine) _engine = engine _melSpectrogram = melSpectrogram _timeline = DiarizerTimeline(config: _timelineConfig) _session = nil resetBuffersLocked() - + logger.info( "Initialized LS-EEND \(descriptor.variant.rawValue): " + "\(engine.metadata.realOutputDim) speakers, " @@ -177,7 +177,7 @@ public final class LSEENDDiarizer: Diarizer { public func initialize(engine: LSEENDInferenceHelper) { let melSpectrogram = Self.createMelSpectrogram(featureConfig: engine.featureConfig) - lock.withLock { + lock.withLock { updateTimelineConfig(engine: engine) _engine = engine _melSpectrogram = melSpectrogram @@ -246,28 +246,28 @@ public final class LSEENDDiarizer: Diarizer { named name: String?, overwritingAssignedSpeakerName overwriteAssignedSpeakerName: Bool ) throws -> DiarizerSpeaker? { - try lock.withLock { + try lock.withLock { let description: String = name.map { "named '\($0)'" } ?? "(no name)" guard let engine = _engine else { throw LSEENDError.modelPredictionFailed("LS-EEND processor not initialized. Call initialize() first.") } - + let normalized = try normalizeSamplesLocked(samples, sourceSampleRate: sourceSampleRate) ?? samples guard !normalized.isEmpty else { logger.warning("Failed to enroll speaker \(description) because no speech detected") return nil } - + if _timeline.hasSegments { logger.warning("Trying to enroll a speaker while timeline has segments; timeline will be reset") } - + _timeline.reset(keepingSpeakers: true) var occupiedIndices = Set(_timeline.speakers.keys) _numFramesProcessed = 0 _visibleStartFrameOffset = 0 pendingAudio.removeAll(keepingCapacity: true) - + if _session == nil { _session = try engine.createSession( inputSampleRate: engine.targetSampleRate, melSpectrogram: _melSpectrogram!) @@ -275,10 +275,10 @@ public final class LSEENDDiarizer: Diarizer { guard let session = _session else { return nil } - + let update = try session.pushAudio(normalized) let didProcess = update.map { !$0.probabilities.isEmpty || !$0.previewProbabilities.isEmpty } ?? false - + guard didProcess else { let minimumSeconds = engine.streamingLatencySeconds logger.warning( @@ -287,7 +287,7 @@ public final class LSEENDDiarizer: Diarizer { ) return nil } - + if let update { let numSpeakers = engine.metadata.realOutputDim let result = DiarizerChunkResult( @@ -300,7 +300,7 @@ public final class LSEENDDiarizer: Diarizer { _numFramesProcessed += result.finalizedFrameCount _ = try _timeline.addChunk(result) } - + let speaker = _timeline.speakers.values.max { $0.numSpeechFrames < $1.numSpeechFrames } let enrolledSpeaker: DiarizerSpeaker? if let speaker, speaker.hasSegments { @@ -332,13 +332,13 @@ public final class LSEENDDiarizer: Diarizer { _numFramesProcessed = 0 _timeline.reset(keepingSpeakersWhere: { occupiedIndices.contains($0.index) }) pendingAudio.removeAll(keepingCapacity: true) - + logger.info( "Enrolled speaker \(description) with \(normalized.count) samples " + "(\(String(format: "%.1f", Float(normalized.count) / Float(engine.targetSampleRate)))s), " + "visible offset=\(_visibleStartFrameOffset)" ) - + return enrolledSpeaker } } @@ -363,7 +363,7 @@ public final class LSEENDDiarizer: Diarizer { _ samples: C, sourceSampleRate: Double? = nil ) throws where C.Element == Float { - try lock.withLock { + try lock.withLock { if let normalized = try normalizeSamplesLocked(samples, sourceSampleRate: sourceSampleRate) { pendingAudio.append(contentsOf: normalized) } else { @@ -388,7 +388,7 @@ public final class LSEENDDiarizer: Diarizer { samples: C, sourceSampleRate: Double? = nil ) throws -> DiarizerTimelineUpdate? where C.Element == Float { - try lock.withLock { + try lock.withLock { if let normalized = try normalizeSamplesLocked(samples, sourceSampleRate: sourceSampleRate) { pendingAudio.append(contentsOf: normalized) } else { @@ -610,7 +610,7 @@ public final class LSEENDDiarizer: Diarizer { /// /// Preserves the loaded model. Call `initialize()` again to change models. public func reset() { - lock.withLock { + lock.withLock { _session = nil _timeline.reset() resetBuffersLocked()