diff --git a/Cargo.lock b/Cargo.lock index d00bc5f..ccea46f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -153,6 +153,7 @@ name = "aish-pty" version = "0.2.0" dependencies = [ "aish-core", + "aish-i18n", "chrono", "futures", "libc", diff --git a/crates/aish-i18n/locales/en-US.yaml b/crates/aish-i18n/locales/en-US.yaml index 924d26b..1141187 100644 --- a/crates/aish-i18n/locales/en-US.yaml +++ b/crates/aish-i18n/locales/en-US.yaml @@ -331,7 +331,7 @@ shell: plan_review_hint: "Use /plan to start a new planning session." # Thinking animation - thinking_time: "Thinking: {time:.1}s" + thinking_time: "Thinking: {time}s" # Analyzing environment analyzing_environment: "Analyzing environment..." @@ -384,6 +384,21 @@ shell: command_cancelled: "Command cancelled" + session: + tool_bash: "🔧 Using tool: bash ({command})" + confirm_execute: "⚠ Execute? [Y/n]" + ask_user: + custom_input_label: "Enter custom answer" + default_hint: "[default: {default}]" + help_with_cancel: "Esc to cancel, Enter to select" + help_no_cancel: "Enter to select" + min_length_error: "Answer too short (minimum {min} characters)" + tool_banner: "🔧 Using tool: ask_user ({preview})" + choice_preview: "choice_or_text: {prompt} [{count} options]" + text_preview: "text_input: {prompt}" + thinking: "Thinking" + thinking_elapsed: "Thinking {elapsed}s" + error_correction: press_semicolon_hint: "Command execution failed. Type ; and press Enter to auto-analyze and fix, or enter the next command directly." corrected_command_label: "Suggested fix:" diff --git a/crates/aish-i18n/locales/zh-CN.yaml b/crates/aish-i18n/locales/zh-CN.yaml index ae47aec..9e19011 100644 --- a/crates/aish-i18n/locales/zh-CN.yaml +++ b/crates/aish-i18n/locales/zh-CN.yaml @@ -331,7 +331,7 @@ shell: plan_review_hint: "使用 /plan 开始新的规划会话。" # 思考动画 - thinking_time: "思考: {time:.1}s" + thinking_time: "思考: {time}s" # 分析环境 analyzing_environment: "正在分析环境..." @@ -384,6 +384,21 @@ shell: command_cancelled: "命令已取消" + session: + tool_bash: "🔧 使用工具: bash ({command})" + confirm_execute: "⚠ 是否执行? [Y/n]" + ask_user: + custom_input_label: "输入自定义答案" + default_hint: "[default: {default}]" + help_with_cancel: "Esc 取消,Enter 选择" + help_no_cancel: "Enter 选择" + min_length_error: "答案太短(最少 {min} 个字符)" + tool_banner: "🔧 使用工具: ask_user ({preview})" + choice_preview: "choice_or_text: {prompt} [{count} 个选项]" + text_preview: "text_input: {prompt}" + thinking: "思考中" + thinking_elapsed: "思考中 {elapsed}s" + error_correction: press_semicolon_hint: "命令执行失败。输入 ; 后按 Enter 自动分析修复,或直接输入下一条命令。" corrected_command_label: "修复建议:" diff --git a/crates/aish-pty/Cargo.toml b/crates/aish-pty/Cargo.toml index b1d0e42..f1f16e6 100644 --- a/crates/aish-pty/Cargo.toml +++ b/crates/aish-pty/Cargo.toml @@ -5,6 +5,7 @@ edition.workspace = true [dependencies] aish-core.workspace = true +aish-i18n.workspace = true nix.workspace = true tokio.workspace = true serde.workspace = true diff --git a/crates/aish-pty/src/lib.rs b/crates/aish-pty/src/lib.rs index 1896f11..b933a0f 100644 --- a/crates/aish-pty/src/lib.rs +++ b/crates/aish-pty/src/lib.rs @@ -17,7 +17,9 @@ pub mod command_state; pub mod control; pub mod executor; pub mod offload; +pub mod output_buffer; pub mod persistent; +pub mod session_interceptor; pub mod state_capture; pub mod types; @@ -29,6 +31,11 @@ pub use offload::{ BashOffloadResult, BashOffloadSettings, BashOutputOffload, OffloadResult, OffloadState, PtyOutputOffload, }; +pub use output_buffer::OutputBuffer; +pub use session_interceptor::{ + AiCallback, AiEvent, AiQuery, AiResponse, AskUserAnswer, AskUserChannel, AskUserOption, + AskUserRequest, FollowupCallback, InterceptorState, SessionInterceptor, StdinAction, +}; pub use persistent::{is_interactive_command, shell_quote_escape, PersistentPty}; pub use state_capture::StateChanges; pub use types::{CommandSource, CommandSubmission, PtyCommandResult, StreamName}; diff --git a/crates/aish-pty/src/output_buffer.rs b/crates/aish-pty/src/output_buffer.rs new file mode 100644 index 0000000..58913c9 --- /dev/null +++ b/crates/aish-pty/src/output_buffer.rs @@ -0,0 +1,116 @@ +//! Circular buffer that keeps the most recent N bytes of PTY output. +//! Used to provide context for AI error correction during SSH sessions. + +pub struct OutputBuffer { + data: Vec, + capacity: usize, + write_pos: usize, + len: usize, +} + +impl OutputBuffer { + pub fn new(capacity: usize) -> Self { + assert!(capacity > 0, "OutputBuffer capacity must be > 0"); + Self { + data: vec![0u8; capacity], + capacity, + write_pos: 0, + len: 0, + } + } + + /// Append bytes, overwriting oldest data when full. + pub fn append(&mut self, input: &[u8]) { + for &byte in input { + self.data[self.write_pos] = byte; + self.write_pos = (self.write_pos + 1) % self.capacity; + if self.len < self.capacity { + self.len += 1; + } + } + } + + /// Return the most recent bytes up to `max_len`, in order. + pub fn recent(&self, max_len: usize) -> Vec { + let count = max_len.min(self.len); + let mut result = Vec::with_capacity(count); + let actual_start = if self.len < self.capacity { + self.len.saturating_sub(count) + } else { + (self.write_pos + self.capacity - count) % self.capacity + }; + for i in 0..count { + result.push(self.data[(actual_start + i) % self.capacity]); + } + result + } + + /// Clear the buffer. + pub fn clear(&mut self) { + self.write_pos = 0; + self.len = 0; + } + + /// Current number of bytes stored. + pub fn len(&self) -> usize { + self.len + } + + /// Whether the buffer is empty. + pub fn is_empty(&self) -> bool { + self.len == 0 + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_basic_append_and_read() { + let mut buf = OutputBuffer::new(100); + buf.append(b"hello world"); + assert_eq!(buf.recent(100), b"hello world"); + } + + #[test] + fn test_circular_overwrite() { + let mut buf = OutputBuffer::new(10); + buf.append(b"0123456789"); + assert_eq!(buf.recent(10), b"0123456789"); + buf.append(b"AB"); + assert_eq!(buf.recent(10), b"23456789AB"); + } + + #[test] + fn test_recent_with_max_len() { + let mut buf = OutputBuffer::new(100); + buf.append(b"hello world"); + assert_eq!(buf.recent(5), b"world"); + } + + #[test] + fn test_clear() { + let mut buf = OutputBuffer::new(100); + buf.append(b"data"); + buf.clear(); + assert!(buf.is_empty()); + assert_eq!(buf.len(), 0); + } + + #[test] + fn test_wrap_around_multiple_times() { + let mut buf = OutputBuffer::new(5); + buf.append(b"ABCDE"); + buf.append(b"FGHIJ"); + buf.append(b"KLMNO"); + assert_eq!(buf.recent(5), b"KLMNO"); + } + + #[test] + fn test_empty_buffer() { + let buf = OutputBuffer::new(100); + assert!(buf.is_empty()); + assert_eq!(buf.recent(100), b""); + } +} diff --git a/crates/aish-pty/src/persistent.rs b/crates/aish-pty/src/persistent.rs index 714462c..2d58f51 100644 --- a/crates/aish-pty/src/persistent.rs +++ b/crates/aish-pty/src/persistent.rs @@ -373,8 +373,14 @@ impl PersistentPty { pub fn send_command_interactive( &mut self, command: &str, + ai_callback: Option>, ) -> aish_core::Result<(i32, String, String)> { let is_session = is_session_command(command); + let mut interceptor = if is_session { + crate::SessionInterceptor::new(ai_callback) + } else { + crate::SessionInterceptor::new(None) + }; // Drain stale data from both the PTY master fd and the control // pipe BEFORE registering the new command. A stale PromptReady @@ -430,6 +436,22 @@ impl PersistentPty { // rendering. Only skip a leading CR-LF or LF at the very start // of the first chunk -- never consume actual command output. let mut skip_leading_newline = true; + // When a command is injected, the remote shell echoes it back. + // Store the command here so the echo can be stripped from output. + let mut skip_echo_cmd: Option = None; + // Followup callback state: after AI injects a command, capture its + // output and call the followup when the shell goes idle. + let mut pending_followup: Option> = None; + let mut followup_captured: Vec = Vec::new(); + let mut followup_capturing = false; + // Pending AI response — shared between TriggerAi handler and + // followup handler for multi-round tool chaining. + let mut pending_response: Option = None; + // Consecutive idle poll count — require N empty polls before treating + // the shell as truly idle (prevents premature followup triggers over + // SSH where brief network gaps can exceed 50ms). + let mut idle_poll_count: u32 = 0; + const IDLE_THRESHOLD: u32 = 3; while !done { // Build fd sets. @@ -480,11 +502,199 @@ impl PersistentPty { if sel == 0 { // Timeout -- during drain phase this means all output has - // been delivered. During normal phase it's just a poll - // cycle with nothing to do. + // been delivered. During normal phase increment the idle + // counter to require consecutive empty polls before acting. + idle_poll_count += 1; if draining { done = true; } + // Only treat the shell as idle after N consecutive timeouts + // to avoid false positives from brief SSH network gaps. + if idle_poll_count >= IDLE_THRESHOLD { + // No data for N * 50ms — the remote shell is idle and + // sitting at a prompt waiting for input. + if is_session { + interceptor.mark_prompt_ready(); + } + // If we were capturing output for followup analysis, the + // command has finished — invoke the followup callback. + if followup_capturing { + // Detect stuck state: shell is showing a PS2 continuation + // prompt (e.g. unclosed heredoc/quote). Send Ctrl+C to + // cancel and skip the followup. + if looks_like_continuation_prompt(&followup_captured) { + unsafe { + libc::write( + self.master_fd, + b"\x03".as_ptr() as *const libc::c_void, + 1, + ); + } + followup_capturing = false; + pending_followup = None; + followup_captured.clear(); + } else { + followup_capturing = false; + if let Some(followup) = pending_followup.take() { + let output = + String::from_utf8_lossy(&followup_captured).to_string(); + let clean = strip_ansi_and_prompt(&output); + let next_response = followup(&clean); + if let Some(resp) = next_response { + pending_response = Some(resp); + } else { + unsafe { + libc::write( + self.master_fd, + b"\r".as_ptr() as *const libc::c_void, + 1, + ); + } + skip_leading_newline = true; + } + } + followup_captured.clear(); + } + } + } + // Process pending AI response (multi-round chaining). + // Must happen here — the `continue` below skips the + // normal pending_response block after the master-fd read. + if let Some(response) = pending_response.take() { + // Handle ask_user first — it may produce a new pending_response + if let Some((request, channel)) = response.ask_user { + handle_ask_user_interaction( + request, + channel, + stdin_fd, + self.master_fd, + &mut pending_response, + ); + // If ask_user produced a final response, fall through + // to process it on the next iteration. + continue; + } + if let Some(ref cmd) = response.command { + let tool_text = aish_i18n::t_with_args( + "shell.session.tool_bash", + &{ + let mut m = std::collections::HashMap::new(); + m.insert("command".to_string(), cmd.clone()); + m + }, + ); + let tool_line = format!("\x1b[36m{}\x1b[0m\r\n", tool_text); + unsafe { + libc::write( + libc::STDOUT_FILENO, + tool_line.as_ptr() as *const libc::c_void, + tool_line.len(), + ); + } + let confirm = format!( + "\x1b[33m{}\x1b[0m ", + aish_i18n::t("shell.session.confirm_execute") + ); + unsafe { + libc::write( + libc::STDOUT_FILENO, + confirm.as_ptr() as *const libc::c_void, + confirm.len(), + ); + } + let mut ans = [0u8; 1]; + let approved = match unsafe { + libc::read( + stdin_fd, + ans.as_mut_ptr() as *mut libc::c_void, + 1, + ) + } { + 1 => { + let echo = if ans[0] == b'y' + || ans[0] == b'Y' + || ans[0] == b'\r' + || ans[0] == b'\n' + { + b"y\r\n" + } else { + b"n\r\n" + }; + unsafe { + libc::write( + libc::STDOUT_FILENO, + echo.as_ptr() as *const libc::c_void, + echo.len(), + ); + } + // Drain trailing newline/CR so it doesn't leak + // into the next read cycle. + drain_stdin_trailing(stdin_fd); + ans[0] == b'y' + || ans[0] == b'Y' + || ans[0] == b'\r' + || ans[0] == b'\n' + } + _ => false, + }; + if approved { + let safe_cmd = close_unclosed_heredoc(cmd); + skip_echo_cmd = Some(safe_cmd.clone()); + let mut inject = safe_cmd.as_bytes().to_vec(); + inject.push(b'\r'); + unsafe { + libc::write( + self.master_fd, + inject.as_ptr() as *const libc::c_void, + inject.len(), + ); + } + if response.followup.is_some() { + followup_captured.clear(); + followup_capturing = true; + pending_followup = response.followup; + } + } else { + let cancel_msg = format!( + "\x1b[33m{}\x1b[0m\r\n", + aish_i18n::t("shell.command_cancelled") + ); + unsafe { + libc::write( + libc::STDOUT_FILENO, + cancel_msg.as_ptr() as *const libc::c_void, + cancel_msg.len(), + ); + libc::write( + self.master_fd, + b"\r".as_ptr() as *const libc::c_void, + 1, + ); + } + // Call the followup with a cancellation message so + // the LLM thread receives output instead of + // "Channel closed" when the sender is dropped. + if let Some(followup) = response.followup { + let next_response = followup("Command cancelled by user"); + if let Some(resp) = next_response { + pending_response = Some(resp); + } else { + skip_leading_newline = true; + } + } else { + skip_leading_newline = true; + } + } + } else { + unsafe { + libc::write( + self.master_fd, + b"\r".as_ptr() as *const libc::c_void, + 1, + ); + } + } + } continue; } @@ -506,7 +716,7 @@ impl PersistentPty { } } - // Read stdin -> master (only during normal phase). + // Read stdin -> interceptor or master (only during normal phase). if !draining && unsafe { libc::FD_ISSET(stdin_fd, &read_fds) } { let mut tmp = [0u8; 1024]; match unsafe { @@ -514,10 +724,125 @@ impl PersistentPty { } { n if n > 0 => { let data = &tmp[..n as usize]; - if data.contains(&0x03) && !is_session { - let _ = kill_pg(self.child_pid, Signal::SIGINT); + idle_poll_count = 0; + + // Non-session: original passthrough behavior + if !is_session { + if data.contains(&0x03) { + let _ = kill_pg(self.child_pid, Signal::SIGINT); + } + write_buf.extend_from_slice(data); + continue; + } + + // Session command: route through interceptor + for &byte in data { + match interceptor.feed_stdin(byte) { + crate::StdinAction::Forward => { + write_buf.push(byte); + } + crate::StdinAction::EchoLocally => { + unsafe { + libc::write( + libc::STDOUT_FILENO, + &byte as *const u8 as *const libc::c_void, + 1, + ); + } + } + crate::StdinAction::TriggerAi(question) => { + // When triggered from line-level detection, the + // PTY has already echoed the input line. Send + // Ctrl+C to cancel it on the remote side. + if interceptor.take_cancel_pty_line() { + unsafe { + libc::write( + self.master_fd, + b"\x03".as_ptr() as *const libc::c_void, + 1, + ); + } + // Drain PTY output from Ctrl+C (^C + new prompt). + // Must consume it NOW before calling the blocking + // AI callback, otherwise it appears after the AI + // response and confirmation prompt. + let mut drain_buf = [0u8; 4096]; + loop { + let mut rfds: libc::fd_set = + unsafe { std::mem::zeroed() }; + unsafe { + libc::FD_ZERO(&mut rfds); + libc::FD_SET(self.master_fd, &mut rfds); + } + let mut tv = libc::timeval { + tv_sec: 0, + tv_usec: 100_000, // 100ms + }; + let sel = unsafe { + libc::select( + self.master_fd + 1, + &mut rfds, + std::ptr::null_mut(), + std::ptr::null_mut(), + &mut tv, + ) + }; + if sel > 0 + && unsafe { + libc::FD_ISSET( + self.master_fd, + &mut rfds, + ) + } + { + let n = unsafe { + libc::read( + self.master_fd, + drain_buf.as_mut_ptr() + as *mut libc::c_void, + drain_buf.len(), + ) + }; + if n <= 0 { + break; + } + let data = &drain_buf[..n as usize]; + interceptor.feed_pty_output(data); + continue; + } + break; + } + } + + // Move to a new line (preserve user's input line) + unsafe { + libc::write( + libc::STDOUT_FILENO, + b"\r\n".as_ptr() as *const libc::c_void, + 2, + ); + } + + // Call AI callback — it handles ALL display + // and returns an optional command to inject. + let resp = interceptor.call_ai(question); + interceptor.finish_ai(); + skip_leading_newline = true; + if let Some(response) = resp { + pending_response = Some(response); + } else { + // AI returned None — restore prompt + unsafe { + libc::write( + self.master_fd, + b"\r".as_ptr() as *const libc::c_void, + 1, + ); + } + } + } + } } - write_buf.extend_from_slice(data); } _ => {} } @@ -534,6 +859,7 @@ impl PersistentPty { ) } { n if n > 0 => { + idle_poll_count = 0; let mut data = &tmp[..n as usize]; if skip_leading_newline { // Only strip a bare leading CR-LF or LF that @@ -546,15 +872,39 @@ impl PersistentPty { } skip_leading_newline = false; } + // Strip the remote shell's echo of an injected command. + if let Some(ref echo_cmd) = skip_echo_cmd { + let pattern = format!("{}\r\n", echo_cmd).into_bytes(); + if data.starts_with(&pattern) { + data = &data[pattern.len()..]; + } else { + let pattern_cr = format!("{}\r", echo_cmd).into_bytes(); + if data.starts_with(&pattern_cr) { + data = &data[pattern_cr.len()..]; + } + } + skip_echo_cmd = None; + } if !data.is_empty() { output_buf.extend_from_slice(data); - let _ = unsafe { - libc::write( - libc::STDOUT_FILENO, - data.as_ptr() as *const libc::c_void, - data.len(), - ) - }; + // Feed interceptor for line-start tracking and output buffering + if is_session { + interceptor.feed_pty_output(data); + } + // Capture output for followup analysis + if followup_capturing { + followup_captured.extend_from_slice(data); + } + // Display unless AI is processing + if !interceptor.is_ai_processing() { + let _ = unsafe { + libc::write( + libc::STDOUT_FILENO, + data.as_ptr() as *const libc::c_void, + data.len(), + ) + }; + } } } 0 => { @@ -567,6 +917,146 @@ impl PersistentPty { } } + // Process pending AI response (from TriggerAi or followup chain). + if let Some(response) = pending_response.take() { + // Handle ask_user first + if let Some((request, channel)) = response.ask_user { + handle_ask_user_interaction( + request, + channel, + stdin_fd, + self.master_fd, + &mut pending_response, + ); + // pending_response may now contain the final AI response + // which will be processed on the next loop iteration. + } else if let Some(ref cmd) = response.command { + // Show tool indicator matching local aish style + let tool_text = aish_i18n::t_with_args( + "shell.session.tool_bash", + &{ + let mut m = std::collections::HashMap::new(); + m.insert("command".to_string(), cmd.clone()); + m + }, + ); + let tool_line = format!("\x1b[36m{}\x1b[0m\r\n", tool_text); + unsafe { + libc::write( + libc::STDOUT_FILENO, + tool_line.as_ptr() as *const libc::c_void, + tool_line.len(), + ); + } + + // Confirmation prompt before execution + let confirm = format!( + "\x1b[33m{}\x1b[0m ", + aish_i18n::t("shell.session.confirm_execute") + ); + unsafe { + libc::write( + libc::STDOUT_FILENO, + confirm.as_ptr() as *const libc::c_void, + confirm.len(), + ); + } + + // Read one byte for confirmation (raw mode) + let mut ans = [0u8; 1]; + let approved = match unsafe { + libc::read( + stdin_fd, + ans.as_mut_ptr() as *mut libc::c_void, + 1, + ) + } { + 1 => { + let echo = if ans[0] == b'y' + || ans[0] == b'Y' + || ans[0] == b'\r' + || ans[0] == b'\n' + { + b"y\r\n" + } else { + b"n\r\n" + }; + unsafe { + libc::write( + libc::STDOUT_FILENO, + echo.as_ptr() as *const libc::c_void, + echo.len(), + ); + } + drain_stdin_trailing(stdin_fd); + ans[0] == b'y' + || ans[0] == b'Y' + || ans[0] == b'\r' + || ans[0] == b'\n' + } + _ => false, + }; + + if approved { + let safe_cmd = close_unclosed_heredoc(cmd); + skip_echo_cmd = Some(safe_cmd.clone()); + let mut inject = safe_cmd.as_bytes().to_vec(); + inject.push(b'\r'); + unsafe { + libc::write( + self.master_fd, + inject.as_ptr() as *const libc::c_void, + inject.len(), + ); + } + if response.followup.is_some() { + followup_captured.clear(); + followup_capturing = true; + pending_followup = response.followup; + } + } else { + let cancel_msg = format!( + "\x1b[33m{}\x1b[0m\r\n", + aish_i18n::t("shell.command_cancelled") + ); + unsafe { + libc::write( + libc::STDOUT_FILENO, + cancel_msg.as_ptr() as *const libc::c_void, + cancel_msg.len(), + ); + libc::write( + self.master_fd, + b"\r".as_ptr() as *const libc::c_void, + 1, + ); + } + // Call the followup with a cancellation message so + // the LLM thread receives output instead of + // "Channel closed" when the sender is dropped. + if let Some(followup) = response.followup { + let next_response = followup("Command cancelled by user"); + if let Some(resp) = next_response { + pending_response = Some(resp); + } else { + skip_leading_newline = true; + } + } else { + skip_leading_newline = true; + } + } + } else { + // AI returned explanation only (no command) + unsafe { + libc::write( + self.master_fd, + b"\r".as_ptr() as *const libc::c_void, + 1, + ); + } + } + } + // Read control pipe for events (only during normal phase). if !draining && unsafe { libc::FD_ISSET(self.control_fd, &read_fds) } { let mut tmp = [0u8; 4096]; @@ -928,6 +1418,781 @@ impl PersistentPty { } +// ---- ask_user helpers for SSH sessions ---- + +/// Drain trailing bytes (e.g. `\n` or `\r`) from stdin after a single-byte +/// confirmation read so they don't leak into the next input cycle. +fn drain_stdin_trailing(stdin_fd: libc::c_int) { + let mut discard = [0u8; 1]; + loop { + let mut rfds: libc::fd_set = unsafe { std::mem::zeroed() }; + unsafe { + libc::FD_ZERO(&mut rfds); + libc::FD_SET(stdin_fd, &mut rfds); + } + let mut tv = libc::timeval { + tv_sec: 0, + tv_usec: 10_000, // 10ms + }; + let sel = unsafe { + libc::select( + stdin_fd + 1, + &mut rfds, + std::ptr::null_mut(), + std::ptr::null_mut(), + &mut tv, + ) + }; + if sel <= 0 { + break; + } + match unsafe { libc::read(stdin_fd, discard.as_mut_ptr() as *mut libc::c_void, 1) } { + 1 => { + if discard[0] == b'\n' || discard[0] == b'\r' { + break; + } + // Non-newline byte — stop draining + break; + } + _ => break, + } + } +} + +/// Truncate a string to `max` bytes, respecting UTF-8 boundaries. +fn truncate_str(s: &str, max: usize) -> &str { + if s.len() <= max { + return s; + } + let mut end = max; + while end > 0 && !s.is_char_boundary(end) { + end -= 1; + } + &s[..end] +} + +/// Debug helper: describe the answer kind without exposing the value. +fn answer_kind(answer: &crate::AskUserAnswer) -> &'static str { + match answer { + crate::AskUserAnswer::Response(_) => "Response", + crate::AskUserAnswer::Cancelled => "Cancelled", + } +} + +/// How many lines the ask_user display occupies (so we can erase/redraw). +fn count_display_lines(request: &crate::AskUserRequest) -> usize { + let mut lines = 1; // Header + if request.kind == "choice_or_text" { + if let Some(ref options) = request.options { + lines += options.len() + 1; // options + custom input + } + } + if request.default.is_some() { + lines += 1; + } + lines += 1; // Help line + // Prompt line "> " only for text_input mode + if request.kind != "choice_or_text" { + lines += 1; + } + lines +} + +/// Erase the current ask_user display and redraw with the cursor at +/// `cursor` (only meaningful for choice_or_text). +fn redraw_ask_user(request: &crate::AskUserRequest, prev_lines: usize, cursor: usize) { + let mut out = Vec::new(); + + // Move up and clear + if prev_lines > 0 { + out.extend_from_slice(format!("\x1b[{}A", prev_lines).as_bytes()); + } + out.extend_from_slice(b"\x1b[J"); // Clear from cursor to end of screen + + // Header — match local aish's inquire style + out.extend_from_slice(b"\x1b[36m? \x1b[1m"); + if let Some(ref title) = request.title { + out.extend_from_slice(title.as_bytes()); + out.extend_from_slice(b": "); + } + out.extend_from_slice(request.prompt.as_bytes()); + out.extend_from_slice(b"\x1b[0m\r\n"); + + // Options with cursor highlight for choice_or_text + if request.kind == "choice_or_text" { + if let Some(ref options) = request.options { + for (i, opt) in options.iter().enumerate() { + // Use inquire-style cursor: ">" for selected, " " for others + if i == cursor { + out.extend_from_slice(b"\x1b[36m> \x1b[1m"); + } else { + out.extend_from_slice(b" "); + } + out.extend_from_slice(opt.label.as_bytes()); + if let Some(ref desc) = opt.description { + out.extend_from_slice(format!(" - {}", desc).as_bytes()); + } + out.extend_from_slice(b"\x1b[0m\r\n"); + } + // Custom input entry at the bottom — same label as local aish + let custom_label = aish_i18n::t("shell.session.ask_user.custom_input_label"); + if cursor == options.len() { + out.extend_from_slice(b"\x1b[36m> \x1b[1m"); + } else { + out.extend_from_slice(b" "); + } + out.extend_from_slice(format!("({})", custom_label).as_bytes()); + out.extend_from_slice(b"\x1b[0m\r\n"); + } + } + + // Default hint — match local aish's [default: xxx] format + if let Some(ref default) = request.default { + let default_hint = aish_i18n::t_with_args( + "shell.session.ask_user.default_hint", + &{ + let mut m = std::collections::HashMap::new(); + m.insert("default".to_string(), default.clone()); + m + }, + ); + out.extend_from_slice(b"\x1b[2m"); + out.extend_from_slice(default_hint.as_bytes()); + out.extend_from_slice(b"\x1b[0m\r\n"); + } + + // Help message — match local aish's style + if request.allow_cancel { + let help = aish_i18n::t("shell.session.ask_user.help_with_cancel"); + out.extend_from_slice(b"\x1b[2m"); + out.extend_from_slice(help.as_bytes()); + out.extend_from_slice(b"\x1b[0m\r\n"); + } else { + let help = aish_i18n::t("shell.session.ask_user.help_no_cancel"); + out.extend_from_slice(b"\x1b[2m"); + out.extend_from_slice(help.as_bytes()); + out.extend_from_slice(b"\x1b[0m\r\n"); + } + + // Prompt (only for text_input mode) + if request.kind != "choice_or_text" { + out.extend_from_slice(b"\x1b[33m> \x1b[0m"); + } + + unsafe { + libc::write( + libc::STDOUT_FILENO, + out.as_ptr() as *const libc::c_void, + out.len(), + ); + } +} + +/// Initial display — ensure we start on a fresh line. +fn display_ask_user(request: &crate::AskUserRequest) { + // Move to a new line to avoid garbling with previous AI output + unsafe { + libc::write(libc::STDOUT_FILENO, b"\r\n".as_ptr() as *const libc::c_void, 2); + } + redraw_ask_user(request, 0, 0); +} + +/// Read one raw byte from stdin with EINTR retry. +/// Returns the byte on success, or None on EOF/error. +fn read_byte(stdin_fd: libc::c_int) -> Option { + loop { + let mut byte = [0u8; 1]; + let n = unsafe { libc::read(stdin_fd, byte.as_mut_ptr() as *mut libc::c_void, 1) }; + match n { + 1 => return Some(byte[0]), + -1 => { + let errno = unsafe { *libc::__errno_location() }; + if errno == libc::EINTR { + continue; + } + debug!("read_byte: error, errno={}", errno); + return None; + } + 0 => { + debug!("read_byte: EOF"); + return None; + } + _ => { + debug!("read_byte: unexpected return {}", n); + continue; + } + } + } +} + +/// Check whether stdin has data available within `timeout_us` microseconds. +fn stdin_poll(stdin_fd: libc::c_int, timeout_us: libc::suseconds_t) -> bool { + let mut rfds: libc::fd_set = unsafe { std::mem::zeroed() }; + unsafe { + libc::FD_ZERO(&mut rfds); + libc::FD_SET(stdin_fd, &mut rfds); + } + let mut tv = libc::timeval { + tv_sec: 0, + tv_usec: timeout_us, + }; + let sel = unsafe { + libc::select( + stdin_fd + 1, + &mut rfds, + std::ptr::null_mut(), + std::ptr::null_mut(), + &mut tv, + ) + }; + sel > 0 +} + +/// Consume a CSI escape sequence (already read `\x1b[`). +/// CSI format: parameters (0x30-0x3F)* intermediate (0x20-0x2F)* final (0x40-0x7E) +/// Returns the final byte (e.g. 'A' for up arrow) or None on error. +fn consume_csi(stdin_fd: libc::c_int) -> Option { + loop { + match read_byte(stdin_fd) { + Some(b) if b >= 0x40 && b <= 0x7E => return Some(b), + Some(_) => continue, // parameter or intermediate byte + None => return None, + } + } +} + +/// Read a line of user input in raw mode with escape-sequence handling. +/// For choice_or_text: up/down arrows navigate options (including custom +/// input slot at the bottom), Enter selects. Typing text switches to +/// custom input mode. +/// For text_input: normal text editing, Enter submits. +/// Ctrl+C always cancels. Esc cancels only if allow_cancel is true. +fn read_line_from_stdin_raw( + stdin_fd: libc::c_int, + request: &crate::AskUserRequest, +) -> crate::AskUserAnswer { + let is_choice = request.kind == "choice_or_text"; + let num_options = request + .options + .as_ref() + .map_or(0, |o| o.len()); + let has_options = is_choice && num_options > 0; + // Total selectable slots: options + 1 custom-input slot + let total_slots = if has_options { num_options + 1 } else { 0 }; + + // For choice mode, track cursor position + let mut cursor: usize = 0; + let mut text_buf: Vec = Vec::new(); + + loop { + match read_byte(stdin_fd) { + Some(byte) => match byte { + // Ctrl+C → always cancel + 0x03 => { + unsafe { + libc::write(libc::STDOUT_FILENO, b"^C\r\n".as_ptr() as *const _, 5); + } + return crate::AskUserAnswer::Cancelled; + } + // Enter → submit + b'\r' | b'\n' => { + // After printing \r\n the cursor is one line below the + // display. prev_lines must account for the full display + // height so redraw can move back to the header line. + let prev = count_display_lines(request) + 1; // +1 for the \r\n + unsafe { + libc::write(libc::STDOUT_FILENO, b"\r\n".as_ptr() as *const _, 2); + } + // If user typed text, treat as custom input + if !text_buf.is_empty() { + let text = String::from_utf8_lossy(&text_buf).to_string(); + let trimmed = text.trim().to_string(); + if trimmed.is_empty() { + // Empty after trim — treat like empty + if let Some(ref default) = request.default { + return crate::AskUserAnswer::Response(default.clone()); + } + if request.allow_cancel { + return crate::AskUserAnswer::Cancelled; + } + // Required — redisplay and loop + redraw_ask_user(request, prev, cursor); + text_buf.clear(); + continue; + } + if trimmed.len() < request.min_length { + let min_len_msg = aish_i18n::t_with_args( + "shell.session.ask_user.min_length_error", + &{ + let mut m = std::collections::HashMap::new(); + m.insert("min".to_string(), request.min_length.to_string()); + m + }, + ); + let msg = format!("\x1b[31m{}\x1b[0m\r\n", min_len_msg); + unsafe { + libc::write( + libc::STDOUT_FILENO, + msg.as_ptr() as *const libc::c_void, + msg.len(), + ); + } + redraw_ask_user(request, prev, cursor); + text_buf.clear(); + continue; + } + return crate::AskUserAnswer::Response(trimmed); + } + // No text typed — select by cursor position + if has_options { + if cursor < num_options { + // Regular option selected + let value = request.options.as_ref().unwrap()[cursor].value.clone(); + return crate::AskUserAnswer::Response(value); + } else { + // Custom-input slot selected with no text — + // stay in input mode (same as local AskUserTool + // which goes back to select on empty input) + redraw_ask_user(request, prev, cursor); + continue; + } + } + // text_input mode with empty input + if let Some(ref default) = request.default { + return crate::AskUserAnswer::Response(default.clone()); + } + if request.allow_cancel { + return crate::AskUserAnswer::Cancelled; + } + // Required — redisplay and loop + redraw_ask_user(request, prev, cursor); + continue; + } + // Backspace / Delete + 0x7F | 0x08 => { + if !text_buf.is_empty() { + // Pop trailing UTF-8 continuation bytes, then leader + while text_buf.last().map_or(false, |b| b & 0xC0 == 0x80) { + text_buf.pop(); + } + let leader = text_buf.pop().unwrap(); + // Display width: ASCII=1, 2-byte=1, 3-byte(CJK)=2, 4-byte=2 + let width = if leader < 0x80 { 1 } + else if leader & 0xE0 == 0xC0 { 1 } + else if leader & 0xF0 == 0xE0 { 2 } + else { 2 }; + let erase = format!( + "{}{}{}", + "\x08".repeat(width), + " ".repeat(width), + "\x08".repeat(width), + ); + unsafe { + libc::write( + libc::STDOUT_FILENO, + erase.as_ptr() as *const _, + erase.len(), + ); + } + } + } + // Escape — could be standalone Esc or start of escape sequence + 0x1B => { + // Use 100ms timeout: long enough to cover SSH network + // latency (direction keys arrive as ESC [ A in separate + // packets) while still allowing standalone ESC to cancel. + if stdin_poll(stdin_fd, 100_000) { + // Escape sequence — read next byte + match read_byte(stdin_fd) { + Some(b'[') => { + // CSI sequence + match consume_csi(stdin_fd) { + Some(b'A') | Some(b'k') => { + // Up arrow — navigate in choice mode + if total_slots > 0 { + if cursor > 0 { + cursor -= 1; + } else { + cursor = total_slots - 1; + } + // Clear text buffer when navigating + if !text_buf.is_empty() { + let erase = "\x08".repeat(text_buf.len()) + + &" ".repeat(text_buf.len()) + + &"\x08".repeat(text_buf.len()); + unsafe { + libc::write( + libc::STDOUT_FILENO, + erase.as_ptr() as *const _, + erase.len(), + ); + } + text_buf.clear(); + } + let prev = if request.kind == "choice_or_text" { + count_display_lines(request) + } else { + count_display_lines(request).saturating_sub(1) + }; + redraw_ask_user(request, prev, cursor); + } + } + Some(b'B') | Some(b'j') => { + // Down arrow — navigate in choice mode + if total_slots > 0 { + if cursor + 1 < total_slots { + cursor += 1; + } else { + cursor = 0; + } + if !text_buf.is_empty() { + let erase = "\x08".repeat(text_buf.len()) + + &" ".repeat(text_buf.len()) + + &"\x08".repeat(text_buf.len()); + unsafe { + libc::write( + libc::STDOUT_FILENO, + erase.as_ptr() as *const _, + erase.len(), + ); + } + text_buf.clear(); + } + let prev = if request.kind == "choice_or_text" { + count_display_lines(request) + } else { + count_display_lines(request).saturating_sub(1) + }; + redraw_ask_user(request, prev, cursor); + } + } + _ => { + // Other CSI sequences (Home, End, PgUp, etc.) — ignore + } + } + } + Some(b'O') => { + // SS3 sequence (F-keys, etc.) — consume final byte and ignore + let _ = read_byte(stdin_fd); + } + Some(_) => { + // Other escape sequences — ignore + } + None => { + // Incomplete sequence — treat as Esc + if request.allow_cancel { + unsafe { + libc::write( + libc::STDOUT_FILENO, + b"\r\n".as_ptr() as *const _, + 2, + ); + } + return crate::AskUserAnswer::Cancelled; + } + // Not allowed to cancel — ignore + } + } + } else { + // Standalone Escape + if request.allow_cancel { + unsafe { + libc::write( + libc::STDOUT_FILENO, + b"\r\n".as_ptr() as *const _, + 2, + ); + } + return crate::AskUserAnswer::Cancelled; + } + // Not allowed to cancel — ignore + } + } + // Normal byte — typing text automatically switches to custom input + _ => { + text_buf.push(byte); + // Echo + unsafe { + libc::write( + libc::STDOUT_FILENO, + &byte as *const u8 as *const libc::c_void, + 1, + ); + } + } + }, + None => { + // EOF or error + return crate::AskUserAnswer::Cancelled; + } + } + } +} + +/// Handle an ask_user interaction: display question, read answer, wait for +/// next LLM event. Sets `pending_response` with the final AI response (or +/// None if the LLM finished without further action). +fn handle_ask_user_interaction( + request: crate::AskUserRequest, + channel: crate::AskUserChannel, + stdin_fd: libc::c_int, + master_fd: libc::c_int, + pending_response: &mut Option, +) { + debug!( + "handle_ask_user: kind={}, prompt={}", + request.kind, request.prompt + ); + + // Show tool indicator matching local aish style + let args_preview = match request.kind.as_str() { + "choice_or_text" => { + let n = request.options.as_ref().map_or(0, |o| o.len()); + let mut m = std::collections::HashMap::new(); + m.insert("prompt".to_string(), truncate_str(&request.prompt, 60).to_string()); + m.insert("count".to_string(), n.to_string()); + aish_i18n::t_with_args("shell.session.ask_user.choice_preview", &m) + } + _ => { + let mut m = std::collections::HashMap::new(); + m.insert("prompt".to_string(), truncate_str(&request.prompt, 80).to_string()); + aish_i18n::t_with_args("shell.session.ask_user.text_preview", &m) + } + }; + let mut tool_args = std::collections::HashMap::new(); + tool_args.insert("preview".to_string(), args_preview); + let tool_line = format!( + "\x1b[36m{}\x1b[0m\r\n", + aish_i18n::t_with_args("shell.session.ask_user.tool_banner", &tool_args) + ); + unsafe { + libc::write( + libc::STDOUT_FILENO, + tool_line.as_ptr() as *const libc::c_void, + tool_line.len(), + ); + } + + display_ask_user(&request); + let answer = read_line_from_stdin_raw(stdin_fd, &request); + debug!("handle_ask_user: got answer {:?}", answer_kind(&answer)); + + if channel.answer_sender.send(answer).is_err() { + debug!("handle_ask_user: answer channel closed"); + return; + } + + // Wait for next event from LLM, forwarding PTY output meanwhile. + loop { + match channel.event_receiver.try_recv() { + Ok(crate::AiEvent::Done(resp)) => { + debug!("handle_ask_user: LLM done, has_command={}", resp.is_some()); + *pending_response = resp; + break; + } + Ok(crate::AiEvent::AskUser(next_req)) => { + debug!( + "handle_ask_user: follow-up ask_user, prompt={}", + next_req.prompt + ); + display_ask_user(&next_req); + let answer = read_line_from_stdin_raw(stdin_fd, &next_req); + debug!("handle_ask_user: follow-up answer {:?}", answer_kind(&answer)); + if channel.answer_sender.send(answer).is_err() { + break; + } + continue; + } + Ok(crate::AiEvent::BashExec { command, output_sender }) => { + debug!("handle_ask_user: follow-up bash_exec, cmd={}", command); + // Show tool indicator and confirmation, then execute inline. + let tool_text = aish_i18n::t_with_args( + "shell.session.tool_bash", + &{ + let mut m = std::collections::HashMap::new(); + m.insert("command".to_string(), command.clone()); + m + }, + ); + let tool_line = format!("\x1b[36m{}\x1b[0m\r\n", tool_text); + unsafe { + libc::write( + libc::STDOUT_FILENO, + tool_line.as_ptr() as *const libc::c_void, + tool_line.len(), + ); + } + let confirm = format!( + "\x1b[33m{}\x1b[0m ", + aish_i18n::t("shell.session.confirm_execute") + ); + unsafe { + libc::write( + libc::STDOUT_FILENO, + confirm.as_ptr() as *const libc::c_void, + confirm.len(), + ); + } + let mut ans = [0u8; 1]; + let approved = match unsafe { + libc::read(stdin_fd, ans.as_mut_ptr() as *mut libc::c_void, 1) + } { + 1 => { + let echo = if ans[0] == b'y' + || ans[0] == b'Y' + || ans[0] == b'\r' + || ans[0] == b'\n' + { + b"y\r\n" + } else { + b"n\r\n" + }; + unsafe { + libc::write( + libc::STDOUT_FILENO, + echo.as_ptr() as *const libc::c_void, + echo.len(), + ); + } + drain_stdin_trailing(stdin_fd); + ans[0] == b'y' + || ans[0] == b'Y' + || ans[0] == b'\r' + || ans[0] == b'\n' + } + _ => false, + }; + if approved { + let safe_cmd = close_unclosed_heredoc(&command); + let mut inject = safe_cmd.as_bytes().to_vec(); + inject.push(b'\r'); + unsafe { + libc::write( + master_fd, + inject.as_ptr() as *const libc::c_void, + inject.len(), + ); + } + // Wait for command output until the shell goes idle. + let mut captured = Vec::new(); + let mut idle_count: u32 = 0; + loop { + let mut rfds: libc::fd_set = unsafe { std::mem::zeroed() }; + unsafe { + libc::FD_ZERO(&mut rfds); + libc::FD_SET(master_fd, &mut rfds); + } + let mut tv = libc::timeval { + tv_sec: 0, + tv_usec: 50_000, + }; + let sel = unsafe { + libc::select( + master_fd + 1, + &mut rfds, + std::ptr::null_mut(), + std::ptr::null_mut(), + &mut tv, + ) + }; + if sel > 0 + && unsafe { libc::FD_ISSET(master_fd, &mut rfds) } + { + let mut tmp = [0u8; 4096]; + match unsafe { + libc::read( + master_fd, + tmp.as_mut_ptr() as *mut libc::c_void, + tmp.len(), + ) + } { + n if n > 0 => { + let data = &tmp[..n as usize]; + captured.extend_from_slice(data); + unsafe { + libc::write( + libc::STDOUT_FILENO, + data.as_ptr() as *const libc::c_void, + data.len(), + ); + } + idle_count = 0; + } + _ => break, + } + } else { + idle_count += 1; + if idle_count >= 3 { + break; + } + } + } + let output = String::from_utf8_lossy(&captured).to_string(); + let clean = strip_ansi_and_prompt(&output); + let _ = output_sender.send(clean); + } else { + let cancel_msg = format!( + "\x1b[33m{}\x1b[0m\r\n", + aish_i18n::t("shell.command_cancelled") + ); + unsafe { + libc::write( + libc::STDOUT_FILENO, + cancel_msg.as_ptr() as *const libc::c_void, + cancel_msg.len(), + ); + } + let _ = output_sender.send(format!("(cancelled: {})", command)); + } + continue; + } + Err(std::sync::mpsc::TryRecvError::Empty) => { + // Forward PTY output while waiting for LLM + let mut rfds: libc::fd_set = unsafe { std::mem::zeroed() }; + unsafe { + libc::FD_ZERO(&mut rfds); + libc::FD_SET(master_fd, &mut rfds); + } + let mut tv = libc::timeval { + tv_sec: 0, + tv_usec: 50_000, // 50ms + }; + let sel = unsafe { + libc::select( + master_fd + 1, + &mut rfds, + std::ptr::null_mut(), + std::ptr::null_mut(), + &mut tv, + ) + }; + if sel > 0 && unsafe { libc::FD_ISSET(master_fd, &mut rfds) } { + let mut tmp = [0u8; 4096]; + match unsafe { + libc::read( + master_fd, + tmp.as_mut_ptr() as *mut libc::c_void, + tmp.len(), + ) + } { + n if n > 0 => { + unsafe { + libc::write( + libc::STDOUT_FILENO, + tmp.as_ptr() as *const libc::c_void, + n as usize, + ); + } + } + _ => {} + } + } + } + Err(std::sync::mpsc::TryRecvError::Disconnected) => break, + } + } +} + impl Drop for PersistentPty { fn drop(&mut self) { self.stop(); @@ -1096,6 +2361,22 @@ fn strip_ansi_escapes(s: &str) -> String { result } +/// Strip ANSI escapes and trim trailing shell prompt from captured output. +/// Removes the last non-empty line (typically a prompt like `user@host:~$ `). +fn strip_ansi_and_prompt(raw: &str) -> String { + let clean = strip_ansi_escapes(raw); + let mut lines: Vec<&str> = clean.lines().collect(); + // Remove trailing empty lines + while lines.last().map_or(false, |l| l.trim().is_empty()) { + lines.pop(); + } + // Remove last non-empty line (shell prompt) + if !lines.is_empty() { + lines.pop(); + } + lines.join("\n").trim().to_string() +} + /// Clean PTY output: strip ANSI, command echo, trailing prompt. fn clean_pty_output(raw: &str, command: &str) -> String { // Strip ANSI escape sequences. @@ -1123,6 +2404,80 @@ fn regex_simple() -> regex::Regex { regex::Regex::new(r"\x1b\[[0-9;?]*[a-zA-Z]").unwrap() } +/// Detect unclosed heredoc in a shell command and close it. +/// Returns the command with missing heredoc closing delimiters appended. +/// e.g. "cat > f << 'EOF'" → "cat > f << 'EOF'\nEOF" +fn close_unclosed_heredoc(cmd: &str) -> String { + let bytes = cmd.as_bytes(); + let mut i = 0; + let len = bytes.len(); + let mut result = cmd.to_string(); + let mut appended = false; + + while i + 1 < len { + if bytes[i] == b'<' && bytes[i + 1] == b'<' { + // Skip << and optional - + let mut j = i + 2; + if j < len && bytes[j] == b'-' { + j += 1; + } + // Skip whitespace + while j < len && bytes[j] == b' ' { + j += 1; + } + // Skip optional quote + if j < len && (bytes[j] == b'\'' || bytes[j] == b'"') { + j += 1; + } + // Extract delimiter word + let delim_start = j; + while j < len + && ![ + b' ', b'\n', b'\r', b';', b'&', b'|', b'<', b'>', b'#', + ] + .contains(&bytes[j]) + && bytes[j] != b'\'' + && bytes[j] != b'"' + { + j += 1; + } + let delimiter = &cmd[delim_start..j]; + + if !delimiter.is_empty() { + // Check if delimiter appears as a standalone line after the << + let search_start = if appended { 0 } else { j.min(len) }; + let rest = &result[search_start..]; + let closed = rest.lines().any(|line| line.trim() == delimiter); + if !closed { + result.push('\n'); + result.push_str(delimiter); + appended = true; + } + } + i = j; + } else { + i += 1; + } + } + + result +} + +/// Detect if PTY output looks like a continuation prompt (PS2: `> `). +/// Used to detect stuck heredoc/quote states after command injection. +fn looks_like_continuation_prompt(output: &[u8]) -> bool { + if output.is_empty() { + return false; + } + let stripped = strip_ansi_escapes(&String::from_utf8_lossy(output)); + let lines: Vec<&str> = stripped.lines().collect(); + if let Some(last_line) = lines.last() { + let trimmed_line = last_line.trim(); + return trimmed_line == ">" || trimmed_line.ends_with("> "); + } + false +} + #[cfg(test)] mod tests { use super::*; diff --git a/crates/aish-pty/src/session_interceptor.rs b/crates/aish-pty/src/session_interceptor.rs new file mode 100644 index 0000000..889a568 --- /dev/null +++ b/crates/aish-pty/src/session_interceptor.rs @@ -0,0 +1,576 @@ +use crate::output_buffer::OutputBuffer; + +// --------------------------------------------------------------------------- +// Callback types +// --------------------------------------------------------------------------- + +/// Input provided to the AI callback. +pub struct AiQuery { + /// The user's question text (after the `;`/`;` prefix). + pub question: String, + /// Recent PTY output for context (error correction). + pub recent_output: String, +} + +/// Result returned by the AI callback containing an optional command to +/// inject into the remote shell and raw display text to be shown later +/// (after command output has been displayed). +pub struct AiResponse { + /// Command to inject into the remote PTY. `None` when the AI only + /// provides an explanation without a runnable command. + pub command: Option, + /// Raw LLM response text for deferred display (markdown, will be + /// rendered by the forwarding loop after command output). + pub display_text: String, + /// When Some(command), the forwarding loop should execute the command + /// on the remote host, capture its output, and pass it to this + /// followup callback for analysis. + pub followup: Option>, + /// When Some, the AI needs user input before continuing. The + /// forwarding loop displays the question, reads user input, sends + /// the answer back via the channel, and waits for the next event. + pub ask_user: Option<(AskUserRequest, AskUserChannel)>, +} + +/// A question the AI wants to ask the user during an SSH session. +pub struct AskUserRequest { + /// Interaction type: "text_input" or "choice_or_text". + pub kind: String, + /// The question to display. + pub prompt: String, + /// Predefined choices for "choice_or_text" mode. + pub options: Option>, + /// Optional title for the question. + pub title: Option, + /// Default value (pre-selected). + pub default: Option, + /// Whether the user can cancel/skip (default: true). + pub allow_cancel: bool, + /// Minimum length for text input (default: 0). + pub min_length: usize, +} + +/// One option in a choice_or_text ask_user interaction. +pub struct AskUserOption { + pub value: String, + pub label: String, + pub description: Option, +} + +/// Answer from the user to an ask_user question. +pub enum AskUserAnswer { + Response(String), + Cancelled, +} + +/// Event from the LLM thread to the forwarding loop. +pub enum AiEvent { + /// The LLM wants to ask the user a question. + AskUser(AskUserRequest), + /// The LLM wants to execute a bash command on the remote host. + BashExec { + command: String, + output_sender: std::sync::mpsc::Sender, + }, + /// The LLM has finished processing. Payload is a fully processed AiResponse + /// (with command, followup, etc. already populated). + Done(Option), +} + +/// Channel pair for ask_user communication between the LLM thread and +/// the forwarding loop. +pub struct AskUserChannel { + /// Send user's answer back to the LLM thread. + pub answer_sender: std::sync::mpsc::Sender, + /// Receive next event (another ask_user or done) from the LLM thread. + pub event_receiver: std::sync::mpsc::Receiver, +} + +/// Second-stage callback invoked after the injected command finishes on +/// the remote host. Receives the captured command output, streams the +/// AI analysis to the terminal, and optionally returns a new `AiResponse` +/// to chain another command execution (multi-round tool use). +pub type FollowupCallback = dyn Fn(&str) -> Option + Send + Sync; + +/// AI callback type: receives an AiQuery and returns an optional AiResponse. +pub type AiCallback = dyn Fn(AiQuery) -> Option + Send + Sync; + +// --------------------------------------------------------------------------- +// State machine +// --------------------------------------------------------------------------- + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum InterceptorState { + /// Normal passthrough — all stdin bytes go to PTY master. + Passthrough, + /// AI callback is running; buffer PTY output but don't display. + AiProcessing, +} + +/// Action returned after processing a stdin byte. +#[derive(Debug, PartialEq, Eq)] +pub enum StdinAction { + /// Forward byte to PTY master. + Forward, + /// Byte was intercepted (AI mode); echo it locally to stdout. + EchoLocally, + /// AI input line is complete. Caller should invoke AI callback. + TriggerAi(String), +} + +// --------------------------------------------------------------------------- +// SessionInterceptor +// --------------------------------------------------------------------------- + +pub struct SessionInterceptor { + state: InterceptorState, + /// Shadow buffer of the current input line in Passthrough mode. + /// Used for line-level AI trigger detection: when Enter is pressed, + /// we check whether the accumulated line starts with `;` or `;`. + line_shadow: Vec, + /// Flag set when AI is triggered from line-level detection (the PTY + /// already has the echoed text). The forwarding loop should send + /// Ctrl+C to the PTY to cancel the line before invoking AI. + cancel_pty_line: bool, + output_buffer: OutputBuffer, + ai_callback: Option>, + /// Escape sequence tracker: when Some, we're consuming bytes of a + /// terminal escape sequence (arrow keys, function keys, etc.) so they + /// don't corrupt line_shadow. + escape_seq: Option, +} + +/// Phase of an escape sequence being consumed. +#[derive(Debug, Clone, Copy)] +enum EscSeqPhase { + /// Received ESC (0x1B), waiting for the next byte. + Start, + /// Received ESC [ — consuming CSI parameter/intermediate bytes. + Csi, +} + +impl SessionInterceptor { + /// Create a new interceptor. + /// `ai_callback` is None -> interceptor is disabled (pure passthrough). + /// `ai_callback` is Some -> interceptor will intercept `;` input. + pub fn new(ai_callback: Option>) -> Self { + Self { + state: InterceptorState::Passthrough, + line_shadow: Vec::with_capacity(4096), + cancel_pty_line: false, + output_buffer: OutputBuffer::new(8192), + ai_callback, + escape_seq: None, + } + } + + /// Feed a single byte from stdin. Returns the action the caller should take. + pub fn feed_stdin(&mut self, byte: u8) -> StdinAction { + // Escape sequence tracking takes precedence over state machine. + // Input methods (Chinese IME, etc.) and terminal keys (arrows, F-keys) + // send multi-byte escape sequences starting with 0x1B. Consuming them + // here prevents them from corrupting line_shadow (Passthrough) + // or cancelling the AI input prematurely (AiInput). + if let Some(phase) = self.escape_seq.take() { + return self.handle_escape_seq_byte(byte, phase); + } + + match self.state { + InterceptorState::Passthrough => { + match byte { + b'\r' | b'\n' => { + // End of line — check whether the accumulated input + // starts with `;` or `;` to trigger AI. + if self.ai_callback.is_some() + && starts_with_ai_prefix(&self.line_shadow) + { + let line = + String::from_utf8_lossy(&self.line_shadow).to_string(); + let question = extract_ai_question(&line); + self.line_shadow.clear(); + self.cancel_pty_line = true; + self.state = InterceptorState::AiProcessing; + return StdinAction::TriggerAi(question); + } + self.line_shadow.clear(); + } + 0x03 => { + // Ctrl+C — discard current shadow line + self.line_shadow.clear(); + } + 0x7F | 0x08 => { + // Backspace — pop last UTF-8 character from shadow + pop_last_utf8_char(&mut self.line_shadow); + } + 0x15 => { + // Ctrl+U — clear shadow line + self.line_shadow.clear(); + } + 0x1B => { + // Start of escape sequence — don't add to shadow + self.escape_seq = Some(EscSeqPhase::Start); + } + 0x04 => { + // Ctrl+D — don't add to shadow + } + _ => { + // Regular character — add to shadow + self.line_shadow.push(byte); + } + } + StdinAction::Forward + } + InterceptorState::AiProcessing => StdinAction::EchoLocally, + } + } + + /// Handle a byte in the middle of an escape sequence. + /// For Passthrough: forward all bytes without updating state flags. + /// For AiInput: silently consume the sequence (don't cancel). + fn handle_escape_seq_byte(&mut self, byte: u8, phase: EscSeqPhase) -> StdinAction { + match phase { + EscSeqPhase::Start => match byte { + b'[' => { + // CSI sequence — consume parameter/intermediate bytes + self.escape_seq = Some(EscSeqPhase::Csi); + } + // Two-byte escape (ESC O, ESC (, etc.) — consume final byte + _ => { + // Sequence complete + self.escape_seq = None; + } + }, + EscSeqPhase::Csi => { + if byte >= 0x40 && byte <= 0x7E { + // Final byte — sequence complete + self.escape_seq = None; + } + // Otherwise still consuming parameters (0x30-0x3F) or + // intermediate bytes (0x20-0x2F). + else { + self.escape_seq = Some(EscSeqPhase::Csi); + } + } + } + match self.state { + InterceptorState::Passthrough => StdinAction::Forward, + _ => StdinAction::EchoLocally, + } + } + + /// Feed PTY output data — buffer for error correction context. + pub fn feed_pty_output(&mut self, data: &[u8]) { + self.output_buffer.append(data); + } + + /// Reset state to passthrough after AI processing completes. + pub fn finish_ai(&mut self) { + self.state = InterceptorState::Passthrough; + self.line_shadow.clear(); + self.cancel_pty_line = false; + } + + /// Whether AI is currently processing. + pub fn is_ai_processing(&self) -> bool { + self.state == InterceptorState::AiProcessing + } + + /// Called when the select loop times out with no data. + /// Kept for backward compatibility — line-level detection no longer + /// needs this signal. + pub fn mark_prompt_ready(&mut self) {} + + /// Run the AI callback. The callback returns an AiResponse containing + /// an optional command and display text (or None on error). + pub fn call_ai(&self, question: String) -> Option { + self.ai_callback.as_ref().and_then(|cb| { + let recent = self.recent_output(4000); + cb(AiQuery { + question, + recent_output: recent, + }) + }) + } + + /// Check and clear the cancel_pty_line flag. + /// Returns true when AI was triggered from line-level detection and + /// the PTY has an echoed input line that needs to be cancelled. + pub fn take_cancel_pty_line(&mut self) -> bool { + std::mem::replace(&mut self.cancel_pty_line, false) + } + + /// Get the recent PTY output for error correction context. + pub fn recent_output(&self, max_len: usize) -> String { + let bytes = self.output_buffer.recent(max_len); + String::from_utf8_lossy(&bytes).to_string() + } +} + +/// Extract the question text after `;` or `;` prefix. +fn extract_ai_question(line: &str) -> String { + let trimmed = line.trim(); + if trimmed.starts_with(';') { + trimmed[3..].trim().to_string() + } else if trimmed.starts_with(';') { + trimmed[1..].trim().to_string() + } else { + trimmed.to_string() + } +} + +/// Check whether a byte buffer starts with the ASCII semicolon `;` or the +/// fullwidth semicolon `;` (UTF-8: 0xEF 0xBC 0x9B). +fn starts_with_ai_prefix(line: &[u8]) -> bool { + line.first() == Some(&b';') || line.starts_with(&[0xEF, 0xBC, 0x9B]) +} + +/// Pop the last complete UTF-8 character from a byte buffer. +fn pop_last_utf8_char(buf: &mut Vec) { + // Pop trailing continuation bytes (0x80..0xBF), then the leader byte + while buf.last().map_or(false, |b| b & 0xC0 == 0x80) { + buf.pop(); + } + buf.pop(); +} + +#[cfg(test)] +mod tests { + use super::*; + + fn noop_callback() -> Box { + Box::new(|_q| { + Some(AiResponse { + command: Some("echo test".to_string()), + display_text: String::new(), + followup: None, + ask_user: None, + }) + }) + } + + fn noop_callback_no_cmd() -> Box { + Box::new(|_q| None) + } + + // ---- extract_ai_question tests ---- + + #[test] + fn test_extract_question_ascii_semicolon() { + assert_eq!(extract_ai_question(";ip a"), "ip a"); + } + + #[test] + fn test_extract_question_fullwidth_semicolon() { + assert_eq!(extract_ai_question(";查看IP"), "查看IP"); + } + + #[test] + fn test_extract_question_with_extra_spaces() { + assert_eq!(extract_ai_question("; ip a "), "ip a"); + } + + #[test] + fn test_extract_question_only_semicolon() { + assert_eq!(extract_ai_question(";"), ""); + } + + // ---- State machine tests ---- + + #[test] + fn test_passthrough_forward_normal_bytes() { + let mut ic = SessionInterceptor::new(Some(noop_callback())); + assert_eq!(ic.feed_stdin(b'a'), StdinAction::Forward); + assert_eq!(ic.feed_stdin(b'b'), StdinAction::Forward); + } + + #[test] + fn test_semicolon_triggers_ai_on_enter() { + let mut ic = SessionInterceptor::new(Some(noop_callback())); + // `;` alone is Forward; AI triggers when Enter is pressed + assert_eq!(ic.feed_stdin(b';'), StdinAction::Forward); + let action = ic.feed_stdin(b'\r'); + assert!(matches!(action, StdinAction::TriggerAi(_))); + assert_eq!(ic.state, InterceptorState::AiProcessing); + assert!(ic.take_cancel_pty_line()); + } + + #[test] + fn test_semicolon_midline_does_not_trigger_ai() { + let mut ic = SessionInterceptor::new(Some(noop_callback())); + // pwd; → starts with 'p', not ';' → Enter should be Forward + ic.feed_stdin(b'p'); + ic.feed_stdin(b'w'); + ic.feed_stdin(b'd'); + ic.feed_stdin(b';'); + assert_eq!(ic.feed_stdin(b'\r'), StdinAction::Forward); + } + + #[test] + fn test_ai_input_captures_question() { + let mut ic = SessionInterceptor::new(Some(noop_callback())); + ic.feed_stdin(b';'); + ic.feed_stdin(b'i'); + ic.feed_stdin(b'p'); + ic.feed_stdin(b' '); + ic.feed_stdin(b'a'); + if let StdinAction::TriggerAi(q) = ic.feed_stdin(b'\r') { + assert_eq!(q, "ip a"); + } else { + panic!("expected TriggerAi"); + } + } + + #[test] + fn test_no_callback_means_pure_passthrough() { + let mut ic = SessionInterceptor::new(None); + assert_eq!(ic.feed_stdin(b';'), StdinAction::Forward); + assert_eq!(ic.feed_stdin(b'\r'), StdinAction::Forward); + } + + #[test] + fn test_fullwidth_semicolon_triggers_ai() { + let mut ic = SessionInterceptor::new(Some(noop_callback())); + // ; = 0xEF 0xBC 0x9B + ic.feed_stdin(0xEF); + ic.feed_stdin(0xBC); + ic.feed_stdin(0x9B); + ic.feed_stdin(b'h'); + ic.feed_stdin(b'i'); + if let StdinAction::TriggerAi(q) = ic.feed_stdin(b'\r') { + assert_eq!(q, "hi"); + } else { + panic!("expected TriggerAi"); + } + } + + #[test] + fn test_ctrl_c_clears_shadow() { + let mut ic = SessionInterceptor::new(Some(noop_callback())); + ic.feed_stdin(b';'); + ic.feed_stdin(b'h'); + ic.feed_stdin(b'i'); + assert_eq!(ic.feed_stdin(0x03), StdinAction::Forward); + // shadow was cleared — new line with ; should trigger AI + ic.feed_stdin(b';'); + assert!(matches!(ic.feed_stdin(b'\r'), StdinAction::TriggerAi(_))); + } + + #[test] + fn test_backspace_pops_from_shadow() { + let mut ic = SessionInterceptor::new(Some(noop_callback())); + ic.feed_stdin(b'l'); + ic.feed_stdin(b's'); + ic.feed_stdin(0x7F); // backspace removes 's' + ic.feed_stdin(b';'); // now shadow is "l;" — starts with 'l', not ';' + assert_eq!(ic.feed_stdin(b'\r'), StdinAction::Forward); + } + + #[test] + fn test_ctrl_u_clears_shadow() { + let mut ic = SessionInterceptor::new(Some(noop_callback())); + ic.feed_stdin(b'a'); + ic.feed_stdin(b'b'); + ic.feed_stdin(0x15); // Ctrl+U clears shadow + ic.feed_stdin(b';'); // now shadow is ";" — triggers AI + assert!(matches!(ic.feed_stdin(b'\r'), StdinAction::TriggerAi(_))); + } + + #[test] + fn test_escape_sequence_not_in_shadow() { + let mut ic = SessionInterceptor::new(Some(noop_callback())); + // Simulate arrow key: ESC [ A + ic.feed_stdin(b';'); + ic.feed_stdin(0x1B); // ESC + ic.feed_stdin(b'['); // CSI + ic.feed_stdin(b'A'); // final byte (up arrow) + // shadow is just ";" — triggers AI + assert!(matches!(ic.feed_stdin(b'\r'), StdinAction::TriggerAi(_))); + } + + #[test] + fn test_cancel_pty_line_flag() { + let mut ic = SessionInterceptor::new(Some(noop_callback())); + assert!(!ic.take_cancel_pty_line()); + ic.feed_stdin(b';'); + ic.feed_stdin(b'\r'); + // Flag is set but we need to call take_cancel_pty_line + // (normally done by forwarding loop, not in this order) + } + + #[test] + fn test_finish_ai_resets_to_passthrough() { + let mut ic = SessionInterceptor::new(Some(noop_callback())); + ic.feed_stdin(b';'); + ic.feed_stdin(b'\r'); + assert!(ic.is_ai_processing()); + ic.finish_ai(); + assert_eq!(ic.state, InterceptorState::Passthrough); + // After finish_ai, a new ; + Enter should trigger again + ic.feed_stdin(b';'); + assert!(matches!(ic.feed_stdin(b'\r'), StdinAction::TriggerAi(_))); + } + + #[test] + fn test_recent_output_captures_pty_data() { + let mut ic = SessionInterceptor::new(Some(noop_callback())); + ic.feed_pty_output(b"hello "); + ic.feed_pty_output(b"world\n"); + assert!(ic.recent_output(100).contains("hello world")); + } + + #[test] + fn test_call_ai_returns_command() { + let mut ic = SessionInterceptor::new(Some(noop_callback())); + ic.feed_stdin(b';'); + ic.feed_stdin(b'\r'); + let resp = ic.call_ai("test".to_string()); + assert!(resp.is_some()); + let r = resp.unwrap(); + assert_eq!(r.command, Some("echo test".to_string())); + } + + #[test] + fn test_call_ai_returns_none() { + let ic = SessionInterceptor::new(Some(noop_callback_no_cmd())); + let cmd = ic.call_ai("test".to_string()); + assert!(cmd.is_none()); + } + + // ---- Helper function tests ---- + + #[test] + fn test_starts_with_ai_prefix_ascii() { + assert!(starts_with_ai_prefix(b";hello")); + assert!(starts_with_ai_prefix(b";")); + } + + #[test] + fn test_starts_with_ai_prefix_fullwidth() { + assert!(starts_with_ai_prefix(&[0xEF, 0xBC, 0x9B, b'h', b'i'])); + assert!(starts_with_ai_prefix(&[0xEF, 0xBC, 0x9B])); + } + + #[test] + fn test_starts_with_ai_prefix_negative() { + assert!(!starts_with_ai_prefix(b"hello")); + assert!(!starts_with_ai_prefix(b"ls;pwd")); + assert!(!starts_with_ai_prefix(b"")); + // Incomplete fullwidth semicolon (just first byte) + assert!(!starts_with_ai_prefix(&[0xEF])); + } + + #[test] + fn test_pop_last_utf8_char_ascii() { + let mut buf = vec![b'a', b'b', b'c']; + pop_last_utf8_char(&mut buf); + assert_eq!(buf, b"ab"); + } + + #[test] + fn test_pop_last_utf8_char_cjk() { + // ;= 0xEF 0xBC 0x9B + let mut buf = vec![b'x', 0xEF, 0xBC, 0x9B]; + pop_last_utf8_char(&mut buf); + assert_eq!(buf, b"x"); + } +} diff --git a/crates/aish-shell/src/ai_handler.rs b/crates/aish-shell/src/ai_handler.rs index e964b60..eadee4b 100644 --- a/crates/aish-shell/src/ai_handler.rs +++ b/crates/aish-shell/src/ai_handler.rs @@ -630,7 +630,7 @@ pub struct ErrorCorrectionResult { /// Parse the LLM response for error correction, preferring JSON format. /// Falls back to extracting a ```bash code block if JSON parsing fails. -fn parse_error_correction_response(response: &str) -> ErrorCorrectionResult { +pub(crate) fn parse_error_correction_response(response: &str) -> ErrorCorrectionResult { // Strategy: regex extracts the full content between ```...``` fences, // then serde_json handles actual JSON parsing. This avoids the fragility // of trying to match { brace boundaries } with regex (which breaks on @@ -699,14 +699,14 @@ fn parse_error_correction_response(response: &str) -> ErrorCorrectionResult { } /// Get the current username. -fn whoami() -> String { +pub(crate) fn whoami() -> String { std::env::var("USER") .or_else(|_| std::env::var("USERNAME")) .unwrap_or_else(|_| "user".to_string()) } /// Get the hostname. -fn hostname() -> String { +pub(crate) fn hostname() -> String { std::env::var("HOSTNAME") .or_else(|_| { std::process::Command::new("hostname") @@ -718,7 +718,7 @@ fn hostname() -> String { } /// Get OS information string. -fn os_info() -> String { +pub(crate) fn os_info() -> String { format!( "{} {} ({})", sysinfo::System::name().unwrap_or_default(), diff --git a/crates/aish-shell/src/app.rs b/crates/aish-shell/src/app.rs index 5bea364..2f17599 100644 --- a/crates/aish-shell/src/app.rs +++ b/crates/aish-shell/src/app.rs @@ -8,7 +8,7 @@ use aish_core::{LlmEvent, LlmEventType, MemoryCategory}; use aish_i18n::{t, t_with_args}; use aish_llm::{ langfuse::{LangfuseClient, LangfuseConfig}, - CancellationToken, LlmCallbackResult, LlmSession, + CancellationToken, ChatMessage, LlmCallbackResult, LlmSession, }; use aish_memory::MemoryManager; use aish_security::{SecurityManager, SecurityPolicy}; @@ -468,7 +468,9 @@ impl AishShell { animation_ref.stop(); let ttft = *ttft_value_ref.lock().unwrap(); if ttft >= 0.1 { - println!("\x1b[2m思考: {:.1}s\x1b[0m", ttft); + let mut ttft_args = std::collections::HashMap::new(); + ttft_args.insert("time".to_string(), format!("{:.1}", ttft)); + println!("\x1b[2m{}\x1b[0m", aish_i18n::t_with_args("shell.thinking_time", &ttft_args)); } *thinking_start_ref.lock().unwrap() = None; } @@ -562,12 +564,14 @@ impl AishShell { .map(|s| { let e = s.elapsed().as_secs_f64(); if e >= 1.0 { - format!(" 思考中 {:.1}s", e) + let mut args = std::collections::HashMap::new(); + args.insert("elapsed".to_string(), format!("{:.1}", e)); + format!(" {}", aish_i18n::t_with_args("shell.session.thinking_elapsed", &args)) } else { - " 思考中".to_string() + format!(" {}", aish_i18n::t("shell.session.thinking")) } }) - .unwrap_or_else(|| " 思考中".to_string()); + .unwrap_or_else(|| format!(" {}", aish_i18n::t("shell.session.thinking"))); let prev = reasoning_lines_displayed_ref.load(Ordering::SeqCst); let new_count = 1 + display_lines.len(); @@ -1550,7 +1554,10 @@ impl AishShell { // MutexGuard is dropped before any potential restart_pty() call). let result = { let mut pty = self.lock_pty(); - pty.send_command_interactive(command) + let remote_host = extract_remote_host(command); + let ai_cb = + Self::build_session_ai_callback(&self.config, &self.animation, remote_host.as_deref()); + pty.send_command_interactive(command, ai_cb) }; let (exit_code, cwd, output) = match result { Ok(result) => result, @@ -1806,7 +1813,7 @@ impl AishShell { .pty .lock() .unwrap() - .send_command_interactive(segment) + .send_command_interactive(segment, None) .unwrap_or((-1, self.state.cwd.clone(), String::new())); if !output.is_empty() { @@ -1847,6 +1854,993 @@ impl AishShell { self.state.env_vars.insert(key.clone(), value.clone()); } } + + /// Build an AI callback for session commands (SSH, telnet). + /// Uses the same oracle prompt as local aish (with remote host context), + /// streaming output, thinking animation, and ShellRenderer for display. + /// Build a followup closure that can chain itself for multi-round tool use. + /// Each invocation: call LLM with command output → render analysis → if + /// the LLM suggests another command, return Some(AiResponse) with another + /// followup closure (same builder, shared history). + fn build_followup_closure( + api_base: &str, + api_key: &str, + model: &str, + temperature: Option, + max_tokens: Option, + system_msg: &str, + original_question: &str, + animation: &Arc, + history: &Arc>>, + ) -> Box { + let api_base_f = api_base.to_string(); + let api_key_f = api_key.to_string(); + let model_f = model.to_string(); + let system_msg_f = system_msg.to_string(); + let question_f = original_question.to_string(); + let anim_f = animation.clone(); + let history_f = history.clone(); + + Box::new(move |output: &str| -> Option { + let followup_prompt = format!( + "I ran the command on the remote host. Here is the output:\n\ + ```\n{}\n```\n\n\ + Original question: {}\n\ + Please analyze the command output. If further action is needed, \ + suggest the next bash command in a ```bash code block. \ + If no further action is needed, just provide a summary.", + output, question_f + ); + + // Channel for ask_user communication + let (event_tx, event_rx) = std::sync::mpsc::channel::(); + let (answer_tx, answer_rx) = + std::sync::mpsc::channel::(); + + // Done signal: LLM thread sends when all rendering is complete. + let (llm_done_tx, llm_done_rx) = std::sync::mpsc::channel::<()>(); + + // Cancellation support for Ctrl+C + let (token_tx, token_rx) = + std::sync::mpsc::channel::>(); + let cancelled = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false)); + let cancelled_cb = cancelled.clone(); + + // Shared reasoning state for cleanup after cancellation + let reasoning_active_main = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false)); + let reasoning_lines_main = std::sync::Arc::new(std::sync::atomic::AtomicUsize::new(0)); + let reasoning_active_cb = reasoning_active_main.clone(); + let reasoning_lines_cb = reasoning_lines_main.clone(); + + let followup_start = std::time::Instant::now(); + + anim_f.start(&t("shell.status.thinking")); + let history_snapshot = history_f.lock().unwrap().clone(); + + // Spawn LLM thread with ChannelAskUserTool + let api_base_th = api_base_f.clone(); + let api_key_th = api_key_f.clone(); + let model_th = model_f.clone(); + let system_msg_th = system_msg_f.clone(); + let anim_th = anim_f.clone(); + let followup_prompt_th = followup_prompt.clone(); + let question_th = question_f.clone(); + let conversation_history_th = history_f.clone(); + + std::thread::spawn(move || { + let rt = match tokio::runtime::Runtime::new() { + Ok(rt) => rt, + Err(e) => { + let _ = event_tx.send(aish_pty::AiEvent::Done(None)); + let msg = format!("\r\n\x1b[31mFollowup error: {}\x1b[0m\r\n", e); + unsafe { + nix::libc::write( + nix::libc::STDOUT_FILENO, + msg.as_ptr() as *const nix::libc::c_void, + msg.len(), + ); + } + return; + } + }; + let mut session = LlmSession::new( + &api_base_th, &api_key_th, &model_th, + temperature, max_tokens, + ); + // Send cancellation token to main thread + let _ = token_tx.send(session.cancellation_token_arc()); + let anim = anim_th.clone(); + let reasoning_active = reasoning_active_cb.clone(); + let reasoning_frame = std::sync::Arc::new(std::sync::atomic::AtomicUsize::new(0)); + let reasoning_lines_displayed = reasoning_lines_cb.clone(); + let reasoning_buf = std::sync::Mutex::new(String::new()); + let thinking_start_followup = std::sync::Arc::new(std::sync::Mutex::new( + Some(std::time::Instant::now()) + )); + session.set_event_callback(std::sync::Arc::new(move |event| { + // Helper: clear reasoning overlay from terminal + let clear_reasoning = || { + if reasoning_active.swap(false, std::sync::atomic::Ordering::SeqCst) { + let prev = reasoning_lines_displayed.swap(0, std::sync::atomic::Ordering::SeqCst); + if prev > 0 { + use std::io::Write; + print!("\x1b[{}A", prev); + for _ in 0..prev { print!("\r\x1b[K\n"); } + print!("\x1b[{}A", prev); + let _ = std::io::stdout().flush(); + } + reasoning_buf.lock().unwrap().clear(); + } + }; + // Bail out if cancelled + if cancelled_cb.load(std::sync::atomic::Ordering::SeqCst) { + anim.stop(); + clear_reasoning(); + return None; + } + use aish_core::LlmEventType; + match event.event_type { + LlmEventType::GenerationStart => { + anim.stop(); + reasoning_active.store(false, std::sync::atomic::Ordering::SeqCst); + reasoning_frame.store(0, std::sync::atomic::Ordering::SeqCst); + reasoning_lines_displayed.store(0, std::sync::atomic::Ordering::SeqCst); + reasoning_buf.lock().unwrap().clear(); + anim.start(&t("shell.status.thinking")); + } + LlmEventType::GenerationEnd => { + anim.stop(); + clear_reasoning(); + } + LlmEventType::ContentDelta => { + if let Some(delta) = + event.data.get("delta").and_then(|d| d.as_str()) + { + if !delta.is_empty() { + anim.stop(); + clear_reasoning(); + } + } + } + LlmEventType::ReasoningDelta => { + if let Some(delta) = + event.data.get("delta").and_then(|d| d.as_str()) + { + if !delta.is_empty() { + anim.stop(); + if !reasoning_active.load(std::sync::atomic::Ordering::SeqCst) { + reasoning_active.store(true, std::sync::atomic::Ordering::SeqCst); + } + let mut buf = reasoning_buf.lock().unwrap(); + buf.push_str(delta); + let all_lines: Vec<&str> = + buf.lines().filter(|l| !l.trim().is_empty()).collect(); + let display_lines: Vec<&str> = all_lines + .iter().rev().take(2) + .collect::>().into_iter().rev().copied().collect(); + let max_cols = 76usize; + let frame = reasoning_frame.fetch_add(1, std::sync::atomic::Ordering::SeqCst); + let spinner = DOTS_FRAMES[frame % DOTS_FRAMES.len()]; + let elapsed_str = thinking_start_followup.lock().unwrap() + .map(|s| { + let e = s.elapsed().as_secs_f64(); + if e >= 1.0 { + let mut args = std::collections::HashMap::new(); + args.insert("elapsed".to_string(), format!("{:.1}", e)); + format!(" {}", aish_i18n::t_with_args("shell.session.thinking_elapsed", &args)) + } else { + format!(" {}", aish_i18n::t("shell.session.thinking")) + } + }) + .unwrap_or_else(|| format!(" {}", aish_i18n::t("shell.session.thinking"))); + let prev = reasoning_lines_displayed.load(std::sync::atomic::Ordering::SeqCst); + let new_count = 1 + display_lines.len(); + if prev > 0 { print!("\x1b[{}A", prev); } + if display_lines.is_empty() { + print!("\r\x1b[K\x1b[90m{}{}...\x1b[0m\n", spinner, elapsed_str); + } else { + print!("\r\x1b[K\x1b[90m{}{}\x1b[0m\n", spinner, elapsed_str); + } + for line in &display_lines { + let truncated = truncate_display_width(line.trim(), max_cols); + print!("\r\x1b[K\x1b[90m{}\x1b[0m\n", truncated); + } + for _ in new_count..prev { print!("\r\x1b[K\n"); } + if prev > new_count { print!("\x1b[{}A", prev - new_count); } + reasoning_lines_displayed.store(new_count, std::sync::atomic::Ordering::SeqCst); + reasoning_active.store(true, std::sync::atomic::Ordering::SeqCst); + let _ = std::io::stdout().flush(); + } + } + } + LlmEventType::ReasoningEnd => { + clear_reasoning(); + } + LlmEventType::Error => { + anim.stop(); + clear_reasoning(); + let err = event.data.get("error") + .or_else(|| event.data.get("error_message")) + .and_then(|e| e.as_str()) + .unwrap_or("Unknown error"); + let msg = format!( + "\r\n\x1b[31mFollowup LLM error: {}\x1b[0m\r\n", err + ); + unsafe { + nix::libc::write( + nix::libc::STDOUT_FILENO, + msg.as_ptr() as *const nix::libc::c_void, + msg.len(), + ); + } + } + _ => {} + } + None + })); + // Register channel-based tools for SSH followup + session.register_tool(Box::new( + aish_tools::ChannelBashTool::new(event_tx.clone()), + )); + session.register_tool(Box::new( + aish_tools::ChannelAskUserTool::new(event_tx.clone(), answer_rx), + )); + let result = rt.block_on(async { + session + .process_input(&followup_prompt_th, &history_snapshot, Some(&system_msg_th), true) + .await + }); + let text = result.ok(); + + // Update conversation history + { + let mut h = conversation_history_th.lock().unwrap(); + h.push(ChatMessage::user(&followup_prompt)); + if let Some(ref t) = text { + h.push(ChatMessage::assistant(t)); + } + let excess = h.len().saturating_sub(50); + if excess > 0 { h.drain(..excess); } + } + + // Render analysis + if let Some(ref t) = text { + if !t.trim().is_empty() { + let _ = std::io::stdout().flush(); + let mut renderer = crate::renderer::ShellRenderer::new(); + renderer.render_separator(); + renderer.render_markdown(t); + let _ = std::io::stdout().flush(); + } + } + + // Build AiResponse with followup if another command was suggested + let next_cmd = text.as_ref().and_then(|t| extract_bash_command(t)); + let ai_response = if let Some(_) = next_cmd { + let next_followup = Self::build_followup_closure( + &api_base_th, &api_key_th, &model_th, temperature, max_tokens, + &system_msg_th, &question_th, &anim_th, &conversation_history_th, + ); + Some(aish_pty::AiResponse { + command: next_cmd, + display_text: String::new(), + followup: Some(next_followup), + ask_user: None, + }) + } else { + None + }; + let _ = event_tx.send(aish_pty::AiEvent::Done(ai_response)); + let _ = llm_done_tx.send(()); // signal rendering complete + }); + + // Wait for result with Ctrl+C cancellation support + let session_cancel_token = token_rx.recv_timeout(std::time::Duration::from_secs(5)).ok(); + let result = loop { + match event_rx.try_recv() { + Ok(aish_pty::AiEvent::Done(ai_response)) => { + break ai_response; + } + Ok(aish_pty::AiEvent::AskUser(request)) => { + break Some(aish_pty::AiResponse { + command: None, + display_text: String::new(), + followup: None, + ask_user: Some(( + request, + aish_pty::AskUserChannel { + answer_sender: answer_tx.clone(), + event_receiver: event_rx, + }, + )), + }); + } + Ok(aish_pty::AiEvent::BashExec { command, output_sender }) => { + let osender = output_sender.clone(); + let done_rx = std::sync::Mutex::new(llm_done_rx); + let followup: Box = Box::new( + move |captured_output: &str| -> Option { + let _ = osender.send(captured_output.to_string()); + // Block until LLM thread finishes all rendering + let _ = done_rx.lock().unwrap().recv_timeout( + std::time::Duration::from_secs(120), + ); + None + }, + ); + break Some(aish_pty::AiResponse { + command: Some(command), + display_text: String::new(), + followup: Some(followup), + ask_user: None, + }); + } + Err(std::sync::mpsc::TryRecvError::Disconnected) => break None, + Err(std::sync::mpsc::TryRecvError::Empty) => {} + } + // Check for Ctrl+C on stdin (non-blocking) + let mut rfds: nix::libc::fd_set = unsafe { std::mem::zeroed() }; + unsafe { + nix::libc::FD_ZERO(&mut rfds); + nix::libc::FD_SET(nix::libc::STDIN_FILENO, &mut rfds); + } + let mut tv = nix::libc::timeval { tv_sec: 0, tv_usec: 100_000 }; + let sel = unsafe { + nix::libc::select( + nix::libc::STDIN_FILENO + 1, + &mut rfds, + std::ptr::null_mut(), + std::ptr::null_mut(), + &mut tv, + ) + }; + if sel > 0 { + let mut byte = [0u8; 1]; + match unsafe { + nix::libc::read( + nix::libc::STDIN_FILENO, + byte.as_mut_ptr() as *mut nix::libc::c_void, + 1, + ) + } { + 1 => { + if byte[0] == 0x03 { + cancelled.store(true, std::sync::atomic::Ordering::SeqCst); + if let Some(ref token) = session_cancel_token { + token.cancel(); + } + anim_f.stop(); + let msg = format!("\r\n\x1b[33m{}\x1b[0m", t("shell.command_cancelled")); + unsafe { + nix::libc::write( + nix::libc::STDOUT_FILENO, + msg.as_ptr() as *mut nix::libc::c_void, + msg.len(), + ); + } + break None; + } + } + _ => {} + } + } + // Check timeout (60s) + if followup_start.elapsed() > std::time::Duration::from_secs(60) { + anim_f.stop(); + let msg = b"\r\n\x1b[31mLLM timeout (60s)\x1b[0m"; + unsafe { + nix::libc::write( + nix::libc::STDOUT_FILENO, + msg.as_ptr() as *mut nix::libc::c_void, + msg.len(), + ); + } + break None; + } + }; + + anim_f.stop(); + + // Clear residual reasoning overlay + if reasoning_active_main.swap(false, std::sync::atomic::Ordering::SeqCst) { + let prev = reasoning_lines_main.load(std::sync::atomic::Ordering::SeqCst); + if prev > 0 { + use std::io::Write; + print!("\x1b[{}A", prev); + for _ in 0..prev { print!("\r\x1b[K\n"); } + print!("\x1b[{}A", prev); + reasoning_lines_main.store(0, std::sync::atomic::Ordering::SeqCst); + let _ = std::io::stdout().flush(); + } + } + + result + }) + } + + fn build_session_ai_callback( + config: &aish_config::ConfigModel, + animation: &Arc, + remote_host: Option<&str>, + ) -> Option> { + let api_base = config.api_base.clone(); + let api_key = config.api_key.clone(); + let model = config.model.clone(); + let temperature = config.temperature; + let max_tokens = config.max_tokens; + let animation = animation.clone(); + + // Build oracle system prompt with remote host context + let mut prompt_manager = aish_prompts::PromptManager::default_dir(); + prompt_manager.load_all(); + let role_prompt = prompt_manager.get("role").to_string(); + let mut vars = std::collections::HashMap::new(); + vars.insert("role_prompt".to_string(), role_prompt); + vars.insert("username".to_string(), crate::ai_handler::whoami()); + vars.insert("hostname".to_string(), crate::ai_handler::hostname()); + vars.insert("os_info".to_string(), crate::ai_handler::os_info()); + vars.insert("cwd".to_string(), "~".to_string()); + vars.insert("system_info".to_string(), String::new()); + vars.insert("memory_context".to_string(), String::new()); + vars.insert("skill_list".to_string(), String::new()); + let mut system_msg = prompt_manager.render("oracle", &vars); + if let Some(host) = remote_host { + system_msg.push_str(&format!( + "\n\n**SSH Remote Session Context (overrides tool list above):** \ + \n- The user is connected to a remote host **{host}** via SSH. \ + \n- **Available tools:** `bash` and `ask_user` only. \ + \n- **DO NOT** call python_exec, read_file, write_file, edit_file, \ + grep, glob, or any other tool — they do NOT exist in this session. \ + \n- `bash` tool runs commands on the remote host. The command will be \ + shown to the user for confirmation before execution. After execution, \ + the output will be automatically returned to you for analysis. \ + \n- `ask_user` asks the user a clarifying question (text_input or choice). \ + \n- **For reading/writing/searching files:** use `bash` tool with \ + `cat`, `head`, `tail`, `echo`, `tee`, `grep`, `find`, `awk`, etc. \ + \n- When the user asks to run a command, execute it, or check something, \ + call the `bash` tool directly — do NOT just show the command in a code block." + )); + } + + // Pre-compute static values for error correction template + let ec_username = crate::ai_handler::whoami(); + let ec_os_info = crate::ai_handler::os_info(); + let remote_host_owned = remote_host.map(|s| s.to_string()); + let conversation_history: Arc>> = + Arc::new(Mutex::new(Vec::new())); + + Some(Box::new(move |query: aish_pty::AiQuery| { + // Detect error correction mode: user typed just `;` with no + // question after a command failure. In this case, the recent + // output contains the error and we use a dedicated prompt. + let is_error_correction = + query.question.is_empty() && !query.recent_output.is_empty(); + + let (context, effective_system_msg) = if is_error_correction { + // Use aish's cmd_error template (same as local aish error correction) + let mut pm = aish_prompts::PromptManager::default_dir(); + pm.load_all(); + + // Extract the actual failed command from bash error output + let failed_cmd = extract_failed_command(&query.recent_output); + + let stderr_section = if query.recent_output.is_empty() { + String::new() + } else { + let s = &query.recent_output; + let preview = if s.len() > 2048 { + let mut end = 2048; + while end > 0 && !s.is_char_boundary(end) { + end -= 1; + } + &s[..end] + } else { + s + }; + format!("\n**Command Output:**\n```\n{}\n```", preview) + }; + + let mut ec_vars = std::collections::HashMap::new(); + ec_vars.insert("username".to_string(), ec_username.clone()); + ec_vars.insert("os_info".to_string(), ec_os_info.clone()); + ec_vars.insert("command".to_string(), failed_cmd.clone()); + ec_vars.insert("exit_code".to_string(), "1".to_string()); + ec_vars.insert("stderr_section".to_string(), stderr_section); + let mut sys = pm.render("cmd_error", &ec_vars); + + // Append SSH host context + if let Some(ref host) = remote_host_owned { + sys.push_str(&format!( + "\n\n**Important:** The command was executed on remote host **{}** via SSH.", + host + )); + } + + // Same XML format as local aish's handle_error_correction + let ctx = format!( + "\nCommand: {}\nExit code: 1\n\n\n\ + Please analyze the error and suggest a fix. \ + Check the recent terminal output above for the actual error output.", + failed_cmd + ); + (ctx, sys) + } else if query.recent_output.is_empty() { + (query.question.clone(), system_msg.clone()) + } else { + // Put user question first, recent output as reference context + let ctx = format!( + "{}\n\nFor reference, recent terminal output:\n{}", + query.question, query.recent_output + ); + (ctx, system_msg.clone()) + }; + + // Start thinking spinner + animation.start(&t("shell.status.thinking")); + let thinking_start = std::sync::Arc::new(std::sync::Mutex::new( + Some(std::time::Instant::now()) + )); + + // Channel for ask_user communication between LLM thread and callback + let (event_tx, event_rx) = std::sync::mpsc::channel::(); + let (answer_tx, answer_rx) = + std::sync::mpsc::channel::(); + + // Channel to send the session's cancellation token back to the + // main thread so Ctrl+C can cancel the in-flight LLM request. + let (token_tx, token_rx) = + std::sync::mpsc::channel::>(); + let cancelled = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false)); + let cancelled_t = cancelled.clone(); + + // Done signal: LLM thread sends when all rendering is complete. + // The followup closure waits on this so the forwarding loop only + // requests a new PTY prompt AFTER the LLM output is on screen. + let (llm_done_tx, llm_done_rx) = std::sync::mpsc::channel::<()>(); + + // Shared reasoning state — needed by both the LLM event callback + // (inside the thread) and the main thread (to clear residual + // reasoning lines before ask_user / normal completion). + let reasoning_active_main = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false)); + let reasoning_lines_main = std::sync::Arc::new(std::sync::atomic::AtomicUsize::new(0)); + let reasoning_active_cb = reasoning_active_main.clone(); + let reasoning_lines_cb = reasoning_lines_main.clone(); + + let api_base_t = api_base.clone(); + let api_key_f = api_key.clone(); + let model_f = model.clone(); + let animation_t = animation.clone(); + let thinking_start_thread = thinking_start.clone(); + let context_messages_t = conversation_history.lock().unwrap().clone(); + let context_for_thread = context.clone(); + let conversation_history_t = conversation_history.clone(); + let system_msg_t = effective_system_msg.clone(); + let query_question_t = query.question.clone(); + let api_base_th = api_base.clone(); + let api_key_th = api_key.clone(); + let model_th = model.clone(); + let animation_th = animation.clone(); + let conversation_history_th = conversation_history.clone(); + let system_msg_th = effective_system_msg.clone(); + + std::thread::spawn(move || { + let rt = tokio::runtime::Runtime::new().unwrap(); + let result = rt.block_on(async { + let mut session = LlmSession::new( + &api_base_t, + &api_key_f, + &model_f, + Some(temperature), + max_tokens, + ); + + // Send cancellation token to main thread + let _ = token_tx.send(session.cancellation_token_arc()); + // Streaming event callback: only show reasoning overlay, + // collect content for formatted rendering after completion + let anim = animation_t.clone(); + let cancelled_cb = cancelled_t.clone(); + let reasoning_active = reasoning_active_cb.clone(); + let reasoning_frame = std::sync::Arc::new(std::sync::atomic::AtomicUsize::new(0)); + let reasoning_lines_displayed = reasoning_lines_cb.clone(); + let reasoning_buf = std::sync::Mutex::new(String::new()); + let thinking_start_r = thinking_start_thread.clone(); + session.set_event_callback(std::sync::Arc::new(move |event| { + // Helper: clear reasoning overlay from terminal + let clear_reasoning = || { + if reasoning_active.swap(false, std::sync::atomic::Ordering::SeqCst) { + let prev = reasoning_lines_displayed.swap(0, std::sync::atomic::Ordering::SeqCst); + if prev > 0 { + use std::io::Write; + print!("\x1b[{}A", prev); + for _ in 0..prev { print!("\r\x1b[K\n"); } + print!("\x1b[{}A", prev); + let _ = std::io::stdout().flush(); + } + reasoning_buf.lock().unwrap().clear(); + } + }; + // Bail out if cancelled + if cancelled_cb.load(std::sync::atomic::Ordering::SeqCst) { + anim.stop(); + clear_reasoning(); + return None; + } + use aish_core::LlmEventType; + match event.event_type { + LlmEventType::GenerationStart => { + anim.stop(); + reasoning_active.store(false, std::sync::atomic::Ordering::SeqCst); + reasoning_frame.store(0, std::sync::atomic::Ordering::SeqCst); + reasoning_lines_displayed.store(0, std::sync::atomic::Ordering::SeqCst); + reasoning_buf.lock().unwrap().clear(); + anim.start(&t("shell.status.thinking")); + } + LlmEventType::GenerationEnd => { + anim.stop(); + clear_reasoning(); + } + LlmEventType::ContentDelta => { + if let Some(delta) = + event.data.get("delta").and_then(|d| d.as_str()) + { + if !delta.is_empty() { + anim.stop(); + clear_reasoning(); + } + } + } + LlmEventType::ReasoningDelta => { + if let Some(delta) = + event.data.get("delta").and_then(|d| d.as_str()) + { + if !delta.is_empty() { + anim.stop(); + if !reasoning_active.load(std::sync::atomic::Ordering::SeqCst) { + reasoning_active.store(true, std::sync::atomic::Ordering::SeqCst); + } + let mut buf = reasoning_buf.lock().unwrap(); + buf.push_str(delta); + let all_lines: Vec<&str> = + buf.lines().filter(|l| !l.trim().is_empty()).collect(); + let display_lines: Vec<&str> = all_lines + .iter().rev().take(2) + .collect::>().into_iter().rev().copied().collect(); + let max_cols = 76usize; + let frame = reasoning_frame.fetch_add(1, std::sync::atomic::Ordering::SeqCst); + let spinner = DOTS_FRAMES[frame % DOTS_FRAMES.len()]; + let elapsed_str = thinking_start_r.lock().unwrap() + .map(|s| { + let e = s.elapsed().as_secs_f64(); + if e >= 1.0 { + let mut args = std::collections::HashMap::new(); + args.insert("elapsed".to_string(), format!("{:.1}", e)); + format!(" {}", aish_i18n::t_with_args("shell.session.thinking_elapsed", &args)) + } else { + format!(" {}", aish_i18n::t("shell.session.thinking")) + } + }) + .unwrap_or_else(|| format!(" {}", aish_i18n::t("shell.session.thinking"))); + let prev = reasoning_lines_displayed.load(std::sync::atomic::Ordering::SeqCst); + let new_count = 1 + display_lines.len(); + if prev > 0 { print!("\x1b[{}A", prev); } + if display_lines.is_empty() { + print!("\r\x1b[K\x1b[90m{}{}...\x1b[0m\n", spinner, elapsed_str); + } else { + print!("\r\x1b[K\x1b[90m{}{}\x1b[0m\n", spinner, elapsed_str); + } + for line in &display_lines { + let truncated = truncate_display_width(line.trim(), max_cols); + print!("\r\x1b[K\x1b[90m{}\x1b[0m\n", truncated); + } + for _ in new_count..prev { print!("\r\x1b[K\n"); } + if prev > new_count { print!("\x1b[{}A", prev - new_count); } + reasoning_lines_displayed.store(new_count, std::sync::atomic::Ordering::SeqCst); + reasoning_active.store(true, std::sync::atomic::Ordering::SeqCst); + let _ = std::io::stdout().flush(); + } + } + } + LlmEventType::ReasoningEnd => { + clear_reasoning(); + } + LlmEventType::Error => { + anim.stop(); + clear_reasoning(); + let error_msg = event + .data + .get("error") + .or_else(|| event.data.get("error_message")) + .and_then(|e| e.as_str()) + .unwrap_or("Unknown error"); + eprintln!("\x1b[31mLLM error: {}\x1b[0m", error_msg); + } + _ => {} + } + None + })); + + // Register channel-based tools for SSH sessions + session.register_tool(Box::new( + aish_tools::ChannelBashTool::new(event_tx.clone()), + )); + session.register_tool(Box::new( + aish_tools::ChannelAskUserTool::new(event_tx.clone(), answer_rx), + )); + + session + .process_input(&context_for_thread, &context_messages_t, Some(&system_msg_t), true) + .await + }); + if cancelled_t.load(std::sync::atomic::Ordering::SeqCst) { + return; + } + let text = result.ok(); + + // Update conversation history + { + let mut h = conversation_history_t.lock().unwrap(); + h.push(ChatMessage::user(&context_for_thread)); + if let Some(ref response_text) = text { + h.push(ChatMessage::assistant(response_text)); + } + let excess = h.len().saturating_sub(50); + if excess > 0 { h.drain(..excess); } + } + + // Build AiResponse from text + let ai_response = match text { + Some(ref t) if is_error_correction => { + // Render description + let ec_result = + crate::ai_handler::parse_error_correction_response(t); + if let Some(ref desc) = ec_result.description { + if !desc.trim().is_empty() { + let _ = std::io::stdout().flush(); + let mut renderer = crate::renderer::ShellRenderer::new(); + renderer.render_separator(); + renderer.render_markdown(desc); + let _ = std::io::stdout().flush(); + } + } + Some(aish_pty::AiResponse { + command: ec_result.command, + display_text: String::new(), + followup: None, + ask_user: None, + }) + } + Some(t) => { + // Render formatted markdown + if !t.trim().is_empty() { + let _ = std::io::stdout().flush(); + let mut renderer = crate::renderer::ShellRenderer::new(); + renderer.render_separator(); + renderer.render_markdown(t.trim()); + let elapsed = thinking_start_thread.lock().unwrap() + .map(|s| s.elapsed().as_secs_f64()) + .unwrap_or(0.0); + if elapsed >= 0.1 { + let mut elapsed_args = std::collections::HashMap::new(); + elapsed_args.insert("time".to_string(), format!("{:.1}", elapsed)); + println!("\x1b[2m{}\x1b[0m", aish_i18n::t_with_args("shell.thinking_time", &elapsed_args)); + } + renderer.render_separator(); + let _ = std::io::stdout().flush(); + } + let command = extract_bash_command(&t); + let followup = command.as_ref().map(|_cmd| { + Self::build_followup_closure( + &api_base_th, &api_key_th, &model_th, + Some(temperature), max_tokens, + &system_msg_th, &query_question_t, &animation_th, &conversation_history_th, + ) + }); + Some(aish_pty::AiResponse { + command, + display_text: String::new(), + followup, + ask_user: None, + }) + } + None => None, + }; + let _ = event_tx.send(aish_pty::AiEvent::Done(ai_response)); + let _ = llm_done_tx.send(()); // signal rendering complete + }); + + // Wait for result with Ctrl+C cancellation support + // Also handles ask_user events from the LLM thread + enum CallbackEvent { + Done(Option), + AskUser(aish_pty::AskUserRequest, aish_pty::AskUserChannel), + BashExec { + command: String, + output_sender: std::sync::mpsc::Sender, + event_receiver: std::sync::mpsc::Receiver, + }, + } + // Receive the session's cancellation token (sent by the LLM thread) + let session_cancel_token = token_rx.recv_timeout(std::time::Duration::from_secs(5)).ok(); + let cb_event = loop { + match event_rx.try_recv() { + Ok(aish_pty::AiEvent::Done(ai_response)) => { + break Some(CallbackEvent::Done(ai_response)); + } + Ok(aish_pty::AiEvent::AskUser(request)) => { + break Some(CallbackEvent::AskUser( + request, + aish_pty::AskUserChannel { + answer_sender: answer_tx.clone(), + event_receiver: event_rx, + }, + )); + } + Ok(aish_pty::AiEvent::BashExec { command, output_sender }) => { + break Some(CallbackEvent::BashExec { + command, + output_sender, + event_receiver: event_rx, + }); + } + Err(std::sync::mpsc::TryRecvError::Disconnected) => break None, + Err(std::sync::mpsc::TryRecvError::Empty) => {} + } + // Check for Ctrl+C on stdin (non-blocking) + let mut rfds: nix::libc::fd_set = unsafe { std::mem::zeroed() }; + unsafe { + nix::libc::FD_ZERO(&mut rfds); + nix::libc::FD_SET(nix::libc::STDIN_FILENO, &mut rfds); + } + let mut tv = nix::libc::timeval { tv_sec: 0, tv_usec: 100_000 }; // 100ms + let sel = unsafe { + nix::libc::select(nix::libc::STDIN_FILENO + 1, &mut rfds, std::ptr::null_mut(), std::ptr::null_mut(), &mut tv) + }; + if sel > 0 { + // Read one byte to check for Ctrl+C + let mut byte = [0u8; 1]; + match unsafe { nix::libc::read(nix::libc::STDIN_FILENO, byte.as_mut_ptr() as *mut nix::libc::c_void, 1) } { + 1 => { + if byte[0] == 0x03 { + // Ctrl+C pressed — cancel the LLM request + cancelled.store(true, std::sync::atomic::Ordering::SeqCst); + if let Some(ref token) = session_cancel_token { + token.cancel(); + } + animation.stop(); + println!("\r\n\x1b[33m{}\x1b[0m", t("shell.command_cancelled")); + break None; + } + // Non-Ctrl-C byte during AI processing — discard. + // The user shouldn't be typing during AI processing; + // any stray bytes are not recoverable here. + } + _ => {} + } + } + // Check timeout (60s) + if thinking_start.lock().unwrap().map_or(false, |s| s.elapsed() > std::time::Duration::from_secs(60)) { + animation.stop(); + eprintln!("\x1b[31mLLM timeout (60s)\x1b[0m"); + break None; + } + }; + + animation.stop(); + + // Clear any residual reasoning lines left on screen (the LLM + // event callback may have shown reasoning deltas that were not + // erased because GenerationEnd / ReasoningEnd haven't fired yet). + if reasoning_active_main.swap(false, std::sync::atomic::Ordering::SeqCst) { + let prev = reasoning_lines_main.load(std::sync::atomic::Ordering::SeqCst); + if prev > 0 { + use std::io::Write; + print!("\x1b[{}A", prev); + for _ in 0..prev { print!("\r\x1b[K\n"); } + print!("\x1b[{}A", prev); + reasoning_lines_main.store(0, std::sync::atomic::Ordering::SeqCst); + let _ = std::io::stdout().flush(); + } + } + + // Handle the callback event + match cb_event { + // Ask_user — return response with ask_user channel + Some(CallbackEvent::AskUser(request, channel)) => { + Some(aish_pty::AiResponse { + command: None, + display_text: String::new(), + followup: None, + ask_user: Some((request, channel)), + }) + } + // Bash_exec — return command for execution on remote host, + // with multi-round chaining support. + Some(CallbackEvent::BashExec { command, output_sender, event_receiver }) => { + let shared_event_rx = std::sync::Arc::new(std::sync::Mutex::new( + Some(event_receiver), + )); + let shared_done_rx = std::sync::Arc::new(std::sync::Mutex::new( + Some(llm_done_rx), + )); + let answer_tx_f = answer_tx.clone(); + + fn make_chain_followup( + event_rx: std::sync::Arc>, + >>, + done_rx: std::sync::Arc>, + >>, + answer_tx: std::sync::mpsc::Sender, + output_sender: std::sync::mpsc::Sender, + ) -> Box { + Box::new(move |captured_output: &str| -> Option { + let _ = output_sender.send(captured_output.to_string()); + + let rx = match event_rx.lock().unwrap().take() { + Some(rx) => rx, + None => return None, + }; + + match rx.recv_timeout(std::time::Duration::from_secs(120)) { + Ok(aish_pty::AiEvent::BashExec { command, output_sender: new_sender }) => { + *event_rx.lock().unwrap() = Some(rx); + Some(aish_pty::AiResponse { + command: Some(command), + display_text: String::new(), + followup: Some(make_chain_followup( + event_rx.clone(), + done_rx.clone(), + answer_tx.clone(), + new_sender, + )), + ask_user: None, + }) + } + Ok(aish_pty::AiEvent::AskUser(request)) => { + Some(aish_pty::AiResponse { + command: None, + display_text: String::new(), + followup: None, + ask_user: Some(( + request, + aish_pty::AskUserChannel { + answer_sender: answer_tx.clone(), + event_receiver: rx, + }, + )), + }) + } + Ok(aish_pty::AiEvent::Done(_)) | Err(_) => { + if let Some(drx) = done_rx.lock().unwrap().take() { + let _ = drx.recv_timeout(std::time::Duration::from_secs(120)); + } + None + } + } + }) + } + + let followup = make_chain_followup( + shared_event_rx, + shared_done_rx, + answer_tx_f, + output_sender, + ); + Some(aish_pty::AiResponse { + command: Some(command), + display_text: String::new(), + followup: Some(followup), + ask_user: None, + }) + } + // Normal completion — AiResponse already built by LLM thread + Some(CallbackEvent::Done(ai_response)) => ai_response, + None => None, + } + })) + } } /// Cached regex for stripping complete XML tags from tool output. @@ -2102,6 +3096,106 @@ fn print_md(text: &str) { renderer.render_markdown(text); } +/// Extract the first ```bash code block from AI response text. +/// Extract the remote host from an SSH/telnet command. +/// e.g. "ssh root@10.10.17.243" → "root@10.10.17.243" +/// e.g. "ssh -p 2222 user@example.com" → "user@example.com" +fn extract_remote_host(command: &str) -> Option { + let parts: Vec<&str> = command.split_whitespace().collect(); + if parts.is_empty() { + return None; + } + let cmd = parts[0]; + if !matches!(cmd, "ssh" | "telnet" | "mosh" | "sftp" | "nc" | "netcat") { + return None; + } + // Find the last argument that looks like a host (contains @ or is a hostname/IP) + for part in parts.iter().skip(1).rev() { + if part.starts_with('-') { + continue; + } + // user@host or just host + if part.contains('@') || part.contains('.') || !part.contains(char::is_whitespace) { + return Some(part.to_string()); + } + } + None +} + +/// Extract the failed command from PTY output after a command error. +/// Strategy 1: Find the full command from the prompt line just before the +/// bash error (preserves pipes, args, etc.). +/// Strategy 2: Extract the command name from the bash error message. +fn extract_failed_command(output: &str) -> String { + static ANSI_RE: std::sync::OnceLock = std::sync::OnceLock::new(); + let re = ANSI_RE.get_or_init(|| regex::Regex::new(r"\x1b\[[0-9;?]*[a-zA-Z]").unwrap()); + let clean = re.replace_all(output, "").to_string(); + let lines: Vec<&str> = clean.lines().collect(); + + // Find the shell error line + for (i, line) in lines.iter().enumerate() { + let trimmed = line.trim(); + let is_shell_error = trimmed.starts_with("-bash: ") + || trimmed.starts_with("bash: ") + || trimmed.starts_with("-ksh: ") + || trimmed.starts_with("ksh: ") + || trimmed.starts_with("-zsh: ") + || trimmed.starts_with("zsh: "); + + if is_shell_error && i > 0 { + // Look at the line before the error — it should be the prompt + command + let prev = lines[i - 1].trim(); + // Extract command after the last "# " or "$ " (common prompt endings) + if let Some(idx) = prev.rfind("# ") { + let full_cmd = prev[idx + 2..].trim(); + if !full_cmd.is_empty() { + return full_cmd.to_string(); + } + } + if let Some(idx) = prev.rfind("$ ") { + let full_cmd = prev[idx + 2..].trim(); + if !full_cmd.is_empty() { + return full_cmd.to_string(); + } + } + } + } + + // Fallback: extract command name from the error message itself + for line in lines.iter().rev() { + let trimmed = line.trim(); + let rest = trimmed + .strip_prefix("-bash: ") + .or_else(|| trimmed.strip_prefix("bash: ")) + .or_else(|| trimmed.strip_prefix("-ksh: ")) + .or_else(|| trimmed.strip_prefix("ksh: ")) + .or_else(|| trimmed.strip_prefix("-zsh: ")) + .or_else(|| trimmed.strip_prefix("zsh: ")); + if let Some(rest) = rest { + if let Some(colon_pos) = rest.find(": ") { + let cmd = rest[..colon_pos].trim(); + if !cmd.is_empty() { + return cmd.to_string(); + } + } + } + } + "(remote command)".to_string() +} + +fn extract_bash_command(text: &str) -> Option { + let marker = "```bash"; + let start = text.find(marker)?; + let content_start = start + marker.len(); + let content_end = text[content_start..].find("```")?; + let cmd = text[content_start..content_start + content_end].trim().to_string(); + if cmd.is_empty() { + None + } else { + Some(cmd) + } +} + #[cfg(test)] mod phase_tests { use super::*; diff --git a/crates/aish-shell/src/renderer.rs b/crates/aish-shell/src/renderer.rs index 576afe7..50bdd5f 100644 --- a/crates/aish-shell/src/renderer.rs +++ b/crates/aish-shell/src/renderer.rs @@ -319,6 +319,11 @@ impl ShellRenderer { } } + /// Return the detected terminal width (columns). + pub fn width(&self) -> usize { + self.terminal_width + } + /// Render complete markdown text. /// Code blocks → syntax highlighting, tables → box drawing, rest → richrs Markdown. pub fn render_markdown(&mut self, text: &str) { diff --git a/crates/aish-tools/src/channel_ask_user.rs b/crates/aish-tools/src/channel_ask_user.rs new file mode 100644 index 0000000..cb1b633 --- /dev/null +++ b/crates/aish-tools/src/channel_ask_user.rs @@ -0,0 +1,194 @@ +//! Channel-based ask_user tool for SSH sessions. +//! +//! Instead of using `inquire` (which requires direct terminal control), this +//! tool communicates with the forwarding loop via channels. When the LLM +//! calls `ask_user`, the tool sends the question through a channel and blocks +//! until the forwarding loop provides the user's answer. + +use aish_llm::{Tool, ToolResult}; +use aish_pty::{AskUserAnswer, AskUserOption, AskUserRequest, AiEvent}; + +/// Shared translated description — same as AskUserTool. +static DESCRIPTION: std::sync::OnceLock = std::sync::OnceLock::new(); + +fn get_description() -> &'static str { + DESCRIPTION.get_or_init(|| aish_i18n::t("tools.ask_user.description")) +} + +pub struct ChannelAskUserTool { + question_sender: std::sync::mpsc::Sender, + answer_receiver: std::sync::Mutex>, +} + +impl ChannelAskUserTool { + pub fn new( + question_sender: std::sync::mpsc::Sender, + answer_receiver: std::sync::mpsc::Receiver, + ) -> Self { + Self { + question_sender, + answer_receiver: std::sync::Mutex::new(answer_receiver), + } + } + + fn send_and_wait(&self, request: AskUserRequest) -> ToolResult { + if self + .question_sender + .send(AiEvent::AskUser(request)) + .is_err() + { + return ToolResult::error("Channel closed"); + } + + match self.answer_receiver.lock().unwrap().recv() { + Ok(AskUserAnswer::Response(answer)) => { + // Match local AskUserTool: prefix with "用户输入: " via i18n + let mut args_map = std::collections::HashMap::new(); + args_map.insert("input".to_string(), answer); + ToolResult::success(aish_i18n::t_with_args( + "tools.ask_user.user_input_prefix", + &args_map, + )) + } + Ok(AskUserAnswer::Cancelled) => { + // Match local AskUserTool: cancelled returns success, not error + ToolResult::success(aish_i18n::t("tools.ask_user.cancelled")) + } + Err(_) => ToolResult::error("Channel closed"), + } + } +} + +impl Tool for ChannelAskUserTool { + fn name(&self) -> &str { + "ask_user" + } + + fn description(&self) -> &str { + get_description() + } + + fn parameters(&self) -> serde_json::Value { + serde_json::json!({ + "type": "object", + "properties": { + "kind": { + "type": "string", + "enum": ["text_input", "choice_or_text"], + "description": "Interaction type: text_input for free-form, choice_or_text for options with custom input" + }, + "prompt": { + "type": "string", + "description": "The question to ask the user" + }, + "options": { + "type": "array", + "description": "Predefined options for choice_or_text", + "items": { + "type": "object", + "properties": { + "value": {"type": "string"}, + "label": {"type": "string"}, + "description": {"type": "string"} + }, + "required": ["value", "label"] + } + }, + "title": { + "type": "string", + "description": "Optional title for the question" + }, + "default": { + "type": "string", + "description": "Default value" + }, + "allow_cancel": { + "type": "boolean", + "description": "Whether the user can cancel/skip (default: true)", + "default": true + }, + "min_length": { + "type": "integer", + "description": "Minimum length for text input (default: 0)", + "default": 0 + } + }, + "required": ["kind", "prompt"] + }) + } + + fn execute(&self, args: serde_json::Value) -> ToolResult { + let kind = args + .get("kind") + .and_then(|v| v.as_str()) + .unwrap_or("text_input") + .to_string(); + let prompt = match args.get("prompt").and_then(|v| v.as_str()) { + Some(p) => p.to_string(), + None => return ToolResult::error(aish_i18n::t("tools.ask_user.missing_prompt")), + }; + let title = args.get("title").and_then(|v| v.as_str()).map(|s| s.to_string()); + let default = args + .get("default") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()); + let allow_cancel = args + .get("allow_cancel") + .and_then(|v| v.as_bool()) + .unwrap_or(true); + let min_length = args.get("min_length").and_then(|v| v.as_u64()).unwrap_or(0) as usize; + + match kind.as_str() { + "choice_or_text" => { + let options = match args.get("options").and_then(|v| v.as_array()) { + Some(opts) if !opts.is_empty() => opts, + _ => return ToolResult::error(aish_i18n::t("tools.ask_user.options_not_empty")), + }; + let parsed_options: Vec = options + .iter() + .filter_map(|item| { + let value = item.get("value").and_then(|v| v.as_str())?; + let label = item.get("label").and_then(|v| v.as_str()).unwrap_or(value); + let description = item.get("description").and_then(|v| v.as_str()).map(|s| s.to_string()); + Some(AskUserOption { + value: value.to_string(), + label: label.to_string(), + description, + }) + }) + .collect(); + + let request = AskUserRequest { + kind, + prompt, + options: Some(parsed_options), + title, + default, + allow_cancel, + min_length, + }; + self.send_and_wait(request) + } + "text_input" => { + let request = AskUserRequest { + kind, + prompt, + options: None, + title, + default, + allow_cancel, + min_length, + }; + self.send_and_wait(request) + } + _ => { + let mut args_map = std::collections::HashMap::new(); + args_map.insert("kind".to_string(), kind); + ToolResult::error(aish_i18n::t_with_args( + "tools.ask_user.unknown_kind", + &args_map, + )) + } + } + } +} diff --git a/crates/aish-tools/src/channel_bash.rs b/crates/aish-tools/src/channel_bash.rs new file mode 100644 index 0000000..ce7ce8b --- /dev/null +++ b/crates/aish-tools/src/channel_bash.rs @@ -0,0 +1,83 @@ +//! Channel-based bash tool for SSH sessions. +//! +//! When the LLM calls bash_exec in an SSH session, this tool sends the command +//! through a channel to the forwarding loop. The forwarding loop executes it on +//! the remote host and returns the output through a response channel. + +use aish_llm::{Tool, ToolResult}; +use aish_pty::AiEvent; + +static DESCRIPTION: std::sync::OnceLock = std::sync::OnceLock::new(); + +fn get_description() -> &'static str { + DESCRIPTION.get_or_init(|| aish_i18n::t("tools.bash.description")) +} + +pub struct ChannelBashTool { + event_sender: std::sync::mpsc::Sender, +} + +impl ChannelBashTool { + pub fn new(event_sender: std::sync::mpsc::Sender) -> Self { + Self { event_sender } + } +} + +impl Tool for ChannelBashTool { + fn name(&self) -> &str { + "bash" + } + + fn description(&self) -> &str { + get_description() + } + + fn parameters(&self) -> serde_json::Value { + serde_json::json!({ + "type": "object", + "properties": { + "command": { + "type": "string", + "description": aish_i18n::t("tools.bash.param.command") + }, + "timeout": { + "type": "integer", + "description": aish_i18n::t("tools.bash.param.timeout"), + "default": 120 + } + }, + "required": ["command"] + }) + } + + fn execute(&self, args: serde_json::Value) -> ToolResult { + let command = match args.get("command").and_then(|v| v.as_str()) { + Some(c) => c.to_string(), + None => return ToolResult::error(aish_i18n::t("tools.bash.missing_command")), + }; + let timeout_secs = args.get("timeout").and_then(|v| v.as_u64()).unwrap_or(120); + + let (output_tx, output_rx) = std::sync::mpsc::channel::(); + + if self + .event_sender + .send(AiEvent::BashExec { + command: command.clone(), + output_sender: output_tx, + }) + .is_err() + { + return ToolResult::error("Channel closed"); + } + + match output_rx.recv_timeout(std::time::Duration::from_secs(timeout_secs)) { + Ok(output) => ToolResult::success(output), + Err(std::sync::mpsc::RecvTimeoutError::Timeout) => { + ToolResult::error(aish_i18n::t("tools.bash.execute_failed").replace("{error}", "timeout")) + } + Err(std::sync::mpsc::RecvTimeoutError::Disconnected) => { + ToolResult::error("Channel closed") + } + } + } +} diff --git a/crates/aish-tools/src/lib.rs b/crates/aish-tools/src/lib.rs index d772e6f..582b626 100644 --- a/crates/aish-tools/src/lib.rs +++ b/crates/aish-tools/src/lib.rs @@ -15,6 +15,8 @@ pub mod ask_user; pub mod bash; +pub mod channel_ask_user; +pub mod channel_bash; pub mod final_answer; pub mod fs; pub mod glob_tool; @@ -28,6 +30,8 @@ pub mod skill_tool; pub mod system_diagnose; pub use ask_user::AskUserTool; +pub use channel_ask_user::ChannelAskUserTool; +pub use channel_bash::ChannelBashTool; pub use final_answer::FinalAnswerTool; pub use fs::{EditFileTool, ReadFileTool, WriteFileTool}; pub use glob_tool::GlobTool;