diff --git a/openless-all/app/src-tauri/src/asr/whisper.rs b/openless-all/app/src-tauri/src/asr/whisper.rs index 6cfc2e52..6659e4c5 100644 --- a/openless-all/app/src-tauri/src/asr/whisper.rs +++ b/openless-all/app/src-tauri/src/asr/whisper.rs @@ -7,6 +7,9 @@ use parking_lot::Mutex; use crate::asr::wav::encode_wav_16k_mono; use crate::asr::RawTranscript; +const PCM_SAMPLE_RATE_HZ: u64 = 16_000; +const PCM_BYTES_PER_SAMPLE: usize = 2; + /// Whisper の `prompt` パラメータの安全側上限(文字数)。 /// /// OpenAI / Groq の Audio Transcriptions API は `prompt` を 244 トークンまで @@ -25,16 +28,25 @@ pub struct WhisperBatchASR { /// 任意のプロンプト(語彙ヒント等)。空文字や空白のみは送信しない。 /// `None` = プロンプト無し(既存挙動)。 prompt: Option, + /// OpenAI 互換でもファイル長に上限がある provider 用。None は従来通り一括送信。 + max_chunk_duration_ms: Option, buffer: Mutex>, } impl WhisperBatchASR { - pub fn new(api_key: String, base_url: String, model: String, prompt: Option) -> Self { + pub fn new( + api_key: String, + base_url: String, + model: String, + prompt: Option, + max_chunk_duration_ms: Option, + ) -> Self { Self { api_key, base_url, model, prompt, + max_chunk_duration_ms, buffer: Mutex::new(Vec::new()), } } @@ -65,13 +77,25 @@ impl WhisperBatchASR { } async fn transcribe_inner(&self, pcm: &[u8]) -> Result { - // 16 kHz mono 16-bit: 2 bytes per sample. - let duration_ms = (pcm.len() as u64 / 2) * 1000 / 16_000; - if self.api_key.is_empty() { anyhow::bail!("Whisper API key missing"); } + let duration_ms = pcm_duration_ms(pcm); + let chunks = split_pcm_by_duration(pcm, self.max_chunk_duration_ms); + let mut texts = Vec::with_capacity(chunks.len()); + + for chunk in chunks { + texts.push(self.transcribe_chunk(chunk).await?); + } + + Ok(RawTranscript { + text: join_transcript_chunks(&texts), + duration_ms, + }) + } + + async fn transcribe_chunk(&self, pcm: &[u8]) -> Result { let samples: Vec = pcm .chunks_exact(2) .map(|chunk| i16::from_le_bytes([chunk[0], chunk[1]])) @@ -114,9 +138,7 @@ impl WhisperBatchASR { } let json: serde_json::Value = resp.json().await.context("parse Whisper response")?; - let text = json["text"].as_str().unwrap_or("").trim().to_string(); - - Ok(RawTranscript { text, duration_ms }) + Ok(json["text"].as_str().unwrap_or("").trim().to_string()) } pub fn cancel(&self) { @@ -130,6 +152,138 @@ impl crate::recorder::AudioConsumer for WhisperBatchASR { } } +fn pcm_duration_ms(pcm: &[u8]) -> u64 { + (pcm.len() as u64 / PCM_BYTES_PER_SAMPLE as u64) * 1000 / PCM_SAMPLE_RATE_HZ +} + +fn split_pcm_by_duration(pcm: &[u8], max_chunk_duration_ms: Option) -> Vec<&[u8]> { + let Some(max_chunk_duration_ms) = max_chunk_duration_ms else { + return vec![pcm]; + }; + if max_chunk_duration_ms == 0 { + return vec![pcm]; + } + + let samples_per_chunk = PCM_SAMPLE_RATE_HZ * max_chunk_duration_ms / 1000; + let bytes_per_chunk = samples_per_chunk as usize * PCM_BYTES_PER_SAMPLE; + if bytes_per_chunk == 0 || pcm.len() <= bytes_per_chunk { + return vec![pcm]; + } + + pcm.chunks(bytes_per_chunk).collect() +} + +fn join_transcript_chunks(chunks: &[String]) -> String { + let mut joined = String::new(); + for chunk in chunks.iter().map(|chunk| chunk.trim()) { + if chunk.is_empty() { + continue; + } + if needs_chunk_separator(&joined, chunk) { + joined.push(' '); + } + joined.push_str(chunk); + } + joined +} + +fn needs_chunk_separator(current: &str, next: &str) -> bool { + let Some(prev) = current.chars().last() else { + return false; + }; + let Some(first) = next.chars().next() else { + return false; + }; + + if is_closing_punctuation(first) || is_opening_punctuation(prev) { + return false; + } + if is_cjk(prev) && (is_cjk(first) || is_opening_punctuation(first)) { + return false; + } + if is_cjk(first) && is_closing_punctuation(prev) { + return false; + } + if is_cjk_punctuation(prev) && is_cjk(first) { + return false; + } + true +} + +fn is_opening_punctuation(ch: char) -> bool { + matches!( + ch, + '(' | '[' | '{' | '"' | '\'' | '(' | '「' | '『' | '《' | '“' | '‘' + ) +} + +fn is_closing_punctuation(ch: char) -> bool { + matches!( + ch, + ',' | '.' + | '!' + | '?' + | ':' + | ';' + | ')' + | ']' + | '}' + | '"' + | '\'' + | ',' + | '。' + | '、' + | '!' + | '?' + | ':' + | ';' + | ')' + | '」' + | '』' + | '》' + | '”' + | '’' + | '…' + ) +} + +fn is_cjk_punctuation(ch: char) -> bool { + matches!( + ch, + ',' | '。' + | '、' + | '!' + | '?' + | ':' + | ';' + | '(' + | ')' + | '「' + | '」' + | '『' + | '』' + | '《' + | '》' + | '“' + | '”' + | '‘' + | '’' + | '…' + | '—' + ) +} + +fn is_cjk(ch: char) -> bool { + matches!( + ch as u32, + 0x3400..=0x4DBF + | 0x4E00..=0x9FFF + | 0x3040..=0x30FF + | 0xAC00..=0xD7AF + | 0xF900..=0xFAFF + ) +} + /// 用户辞書の有効フレーズから Whisper の `prompt` パラメータを組み立てる。 /// /// Whisper は `prompt` で語彙ヒント / スタイル文脈を渡せる:固有名詞・専門 @@ -185,6 +339,11 @@ pub fn build_prompt_from_phrases(phrases: &[String]) -> Option { #[cfg(test)] mod tests { use super::*; + use crate::recorder::AudioConsumer; + use std::io::{Read, Write}; + use std::net::{TcpListener, TcpStream}; + use std::thread; + use std::time::{Duration, Instant}; #[test] fn build_prompt_returns_none_for_empty_input() { @@ -279,4 +438,188 @@ mod tests { // 100 件 × 8 文字以上は確実に予算超過 → 末尾は入らない assert!(!prompt.contains("entry099")); } + + #[test] + fn split_pcm_by_duration_keeps_default_as_single_chunk() { + let pcm = vec![0u8; 96_000]; + assert_eq!(split_pcm_by_duration(&pcm, None), vec![pcm.as_slice()]); + } + + #[test] + fn split_pcm_by_duration_uses_sample_boundaries() { + let pcm = vec![0u8; 32_000 * 65]; + let chunks = split_pcm_by_duration(&pcm, Some(30_000)); + + assert_eq!(chunks.len(), 3); + assert_eq!(chunks[0].len(), 32_000 * 30); + assert_eq!(chunks[1].len(), 32_000 * 30); + assert_eq!(chunks[2].len(), 32_000 * 5); + } + + #[test] + fn split_pcm_by_duration_zero_limit_falls_back_to_single_chunk() { + let pcm = vec![0u8; 96_000]; + assert_eq!(split_pcm_by_duration(&pcm, Some(0)), vec![pcm.as_slice()]); + } + + #[test] + fn join_transcript_chunks_skips_empty_chunks() { + let chunks = vec![" hello ".to_string(), "".to_string(), "world".to_string()]; + assert_eq!(join_transcript_chunks(&chunks), "hello world"); + } + + #[test] + fn join_transcript_chunks_keeps_cjk_together() { + let chunks = vec!["你好".to_string(), "世界".to_string()]; + assert_eq!(join_transcript_chunks(&chunks), "你好世界"); + } + + #[test] + fn join_transcript_chunks_separates_mixed_script_boundaries() { + let chunks = vec!["中文".to_string(), "English".to_string()]; + assert_eq!(join_transcript_chunks(&chunks), "中文 English"); + + let chunks = vec!["OpenLess".to_string(), "中文".to_string()]; + assert_eq!(join_transcript_chunks(&chunks), "OpenLess 中文"); + } + + #[test] + fn join_transcript_chunks_handles_punctuation_boundaries() { + let chunks = vec!["hello,".to_string(), "world".to_string()]; + assert_eq!(join_transcript_chunks(&chunks), "hello, world"); + + let chunks = vec!["hello".to_string(), ",".to_string(), "world".to_string()]; + assert_eq!(join_transcript_chunks(&chunks), "hello, world"); + + let chunks = vec!["foo.".to_string(), "bar".to_string()]; + assert_eq!(join_transcript_chunks(&chunks), "foo. bar"); + + let chunks = vec!["(".to_string(), "hello".to_string(), ")".to_string()]; + assert_eq!(join_transcript_chunks(&chunks), "(hello)"); + } + + #[test] + fn join_transcript_chunks_handles_cjk_punctuation_boundaries() { + let chunks = vec!["你好".to_string(), ",世界".to_string()]; + assert_eq!(join_transcript_chunks(&chunks), "你好,世界"); + + let chunks = vec!["中文".to_string(), "。".to_string(), "下一句".to_string()]; + assert_eq!(join_transcript_chunks(&chunks), "中文。下一句"); + + let chunks = vec!["他说".to_string(), ":".to_string(), "你好".to_string()]; + assert_eq!(join_transcript_chunks(&chunks), "他说:你好"); + + let chunks = vec!["中文。".to_string(), "OpenAI".to_string()]; + assert_eq!(join_transcript_chunks(&chunks), "中文。 OpenAI"); + + let chunks = vec!["「".to_string(), "中文".to_string(), "」".to_string()]; + assert_eq!(join_transcript_chunks(&chunks), "「中文」"); + } + + #[tokio::test] + async fn transcribe_posts_single_request_without_chunk_limit() { + let (base_url, server) = start_whisper_test_server(vec!["one"]); + let asr = + WhisperBatchASR::new("key".to_string(), base_url, "model".to_string(), None, None); + let pcm = vec![0u8; 32_000 * 65]; + asr.consume_pcm_chunk(&pcm); + + let transcript = asr.transcribe().await.unwrap(); + + assert_eq!(transcript.text, "one"); + assert_eq!(transcript.duration_ms, 65_000); + server.join().unwrap(); + } + + #[tokio::test] + async fn transcribe_splits_requests_when_chunk_limit_is_set() { + let (base_url, server) = start_whisper_test_server(vec!["你好", "world", "尾"]); + let asr = WhisperBatchASR::new( + "key".to_string(), + base_url, + "model".to_string(), + None, + Some(30_000), + ); + let pcm = vec![0u8; 32_000 * 65]; + asr.consume_pcm_chunk(&pcm); + + let transcript = asr.transcribe().await.unwrap(); + + assert_eq!(transcript.text, "你好 world 尾"); + assert_eq!(transcript.duration_ms, 65_000); + server.join().unwrap(); + } + + fn start_whisper_test_server(texts: Vec<&'static str>) -> (String, thread::JoinHandle<()>) { + let listener = TcpListener::bind("127.0.0.1:0").unwrap(); + listener.set_nonblocking(true).unwrap(); + let addr = listener.local_addr().unwrap(); + let server = thread::spawn(move || { + let deadline = Instant::now() + Duration::from_secs(5); + for text in texts { + let mut stream = loop { + match listener.accept() { + Ok((stream, _)) => break stream, + Err(err) if err.kind() == std::io::ErrorKind::WouldBlock => { + assert!( + Instant::now() < deadline, + "timed out waiting for ASR test request" + ); + thread::sleep(Duration::from_millis(10)); + } + Err(err) => panic!("accept ASR test request failed: {err}"), + } + }; + stream + .set_read_timeout(Some(Duration::from_secs(5))) + .unwrap(); + let request = read_http_request(&mut stream); + let request_text = String::from_utf8_lossy(&request); + assert!(request_text.starts_with("POST /audio/transcriptions HTTP/1.1")); + assert!(request_text.contains("authorization: Bearer key")); + assert!(request_text.contains("model")); + write_json_response(&mut stream, &format!(r#"{{"text":"{}"}}"#, text)); + } + }); + (format!("http://{}", addr), server) + } + + fn read_http_request(stream: &mut TcpStream) -> Vec { + let mut buf = [0u8; 8192]; + let mut request = Vec::new(); + loop { + let n = stream.read(&mut buf).unwrap(); + if n == 0 { + break; + } + request.extend_from_slice(&buf[..n]); + let Some(header_end) = request.windows(4).position(|w| w == b"\r\n\r\n") else { + continue; + }; + let header_text = String::from_utf8_lossy(&request[..header_end + 4]); + let content_length = header_text + .lines() + .find_map(|line| { + line.strip_prefix("content-length:") + .or_else(|| line.strip_prefix("Content-Length:")) + }) + .and_then(|value| value.trim().parse::().ok()) + .unwrap_or(0); + if request.len() >= header_end + 4 + content_length { + break; + } + } + request + } + + fn write_json_response(stream: &mut TcpStream, body: &str) { + write!( + stream, + "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}", + body.len(), + body + ) + .unwrap(); + } } diff --git a/openless-all/app/src-tauri/src/coordinator.rs b/openless-all/app/src-tauri/src/coordinator.rs index e137b2cb..5b5c81d4 100644 --- a/openless-all/app/src-tauri/src/coordinator.rs +++ b/openless-all/app/src-tauri/src/coordinator.rs @@ -3999,6 +3999,7 @@ mod tests { "http://localhost".to_string(), "model".to_string(), None, + None, )); *coordinator.inner.asr.lock() = Some(SessionResource::new( session_id(2), diff --git a/openless-all/app/src-tauri/src/coordinator/dictation.rs b/openless-all/app/src-tauri/src/coordinator/dictation.rs index 2c138897..65c45b0c 100644 --- a/openless-all/app/src-tauri/src/coordinator/dictation.rs +++ b/openless-all/app/src-tauri/src/coordinator/dictation.rs @@ -741,6 +741,7 @@ pub(super) async fn begin_session(inner: &Arc) -> Result<(), String> { base_url, model, whisper_prompt, + batch_asr_chunk_limit_ms(&active_asr), )); store_asr_for_session( inner, @@ -830,6 +831,13 @@ pub(super) async fn begin_session(inner: &Arc) -> Result<(), String> { Ok(()) } +fn batch_asr_chunk_limit_ms(provider_id: &str) -> Option { + match provider_id { + "zhipu" => Some(30_000), + _ => None, + } +} + pub(super) async fn start_recorder_for_starting( inner: &Arc, session_id: SessionId, @@ -1758,8 +1766,9 @@ fn append_typed_prefix(target: &mut String, delta: &str, typed_chars: usize) -> #[cfg(test)] mod tests { use super::{ - append_typed_prefix, default_done_message, drain_streaming_insert_deltas_with, - finalize_polished_text, flush_streaming_insert_buffer_with, streaming_insert_eligible, + append_typed_prefix, batch_asr_chunk_limit_ms, default_done_message, + drain_streaming_insert_deltas_with, finalize_polished_text, + flush_streaming_insert_buffer_with, streaming_insert_eligible, }; use crate::types::{ChineseScriptPreference, CorrectionRule, InsertStatus, PolishMode}; @@ -1855,6 +1864,15 @@ mod tests { )); } + #[test] + fn batch_asr_chunk_limit_applies_only_to_zhipu() { + assert_eq!(batch_asr_chunk_limit_ms("zhipu"), Some(30_000)); + assert_eq!(batch_asr_chunk_limit_ms("whisper"), None); + assert_eq!(batch_asr_chunk_limit_ms("siliconflow"), None); + assert_eq!(batch_asr_chunk_limit_ms("groq"), None); + assert_eq!(batch_asr_chunk_limit_ms("volcengine"), None); + } + #[test] fn default_done_message_works_correctly() { assert_eq!(