diff --git a/src/aish/llm/session.py b/src/aish/llm/session.py index 0c37a1b..799046d 100644 --- a/src/aish/llm/session.py +++ b/src/aish/llm/session.py @@ -1774,8 +1774,14 @@ async def process_input( else: msg = response["choices"][0]["message"] # type: ignore finish_reason = response["choices"][0]["finish_reason"] # type: ignore - except TimeoutError: - events.emit_cancelled("llm_timeout") + except TimeoutError as err: + if ( + self.cancellation_token + and self.cancellation_token.is_cancelled() + ): + events.emit_cancelled("llm_cancelled") + events.emit_generation_end(status="cancelled") + raise anyio.get_cancelled_exc_class() from err events.emit_generation_end(status="timeout") output = "LLM request timed out" break @@ -2010,9 +2016,12 @@ async def completion( events.emit_cancelled("llm_cancelled") events.emit_generation_end(status="cancelled") raise - except TimeoutError: + except TimeoutError as err: + if self.cancellation_token and self.cancellation_token.is_cancelled(): + events.emit_cancelled("llm_cancelled") + events.emit_generation_end(status="cancelled") + raise anyio.get_cancelled_exc_class() from err result = "LLM request timed out" - events.emit_cancelled("llm_timeout") events.emit_generation_end(status="timeout") except Exception as e: if isinstance(e, Exception) and is_litellm_exception(e): diff --git a/src/aish/shell/runtime/ai.py b/src/aish/shell/runtime/ai.py index 0fd3cdc..a898549 100644 --- a/src/aish/shell/runtime/ai.py +++ b/src/aish/shell/runtime/ai.py @@ -130,7 +130,9 @@ def _shutdown_loop(loop: asyncio.AbstractEventLoop) -> None: def _run_async_in_thread(coro, cancellation_token=None) -> Any: """Run an async coroutine in a separate thread with its own event loop. - Uses polling-based cancellation to allow Ctrl+C interruption. + Uses polling-based cancellation to allow Ctrl+C interruption, and + forwards cancellation into the event loop so in-flight HTTP requests + do not keep running in the background until their timeout expires. """ from concurrent.futures import ( ThreadPoolExecutor, @@ -139,20 +141,38 @@ def _run_async_in_thread(coro, cancellation_token=None) -> Any: result_box: list[Optional[str]] = [None] exc_box: list[BaseException | None] = [None] + loop_ready = threading.Event() + loop_box: list[asyncio.AbstractEventLoop | None] = [None] + task_box: list[asyncio.Task | None] = [None] def run_in_thread() -> None: loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) + loop_box[0] = loop try: - result_box[0] = loop.run_until_complete(coro) + task = loop.create_task(coro) + task_box[0] = task + loop_ready.set() + result_box[0] = loop.run_until_complete(task) except BaseException as e: exc_box[0] = e finally: + loop_ready.set() AIHandler._shutdown_loop(loop) loop.close() + loop_box[0] = None + task_box[0] = None pool = ThreadPoolExecutor(max_workers=1) future = pool.submit(run_in_thread) + + def _cancel_running_task() -> None: + loop = loop_box[0] + task = task_box[0] + if loop is None or task is None or task.done() or loop.is_closed(): + return + loop.call_soon_threadsafe(task.cancel) + try: while not future.done(): try: @@ -160,6 +180,12 @@ def run_in_thread() -> None: except FutureTimeoutError: # Check if cancellation was requested if cancellation_token and cancellation_token.is_cancelled(): + if loop_ready.wait(timeout=0.2): + _cancel_running_task() + try: + future.result(timeout=2.0) + except FutureTimeoutError: + pass raise KeyboardInterrupt("AI operation cancelled by user") finally: pool.shutdown(wait=False) diff --git a/tests/llm/test_llm_events.py b/tests/llm/test_llm_events.py index 17a9f88..f9aa417 100644 --- a/tests/llm/test_llm_events.py +++ b/tests/llm/test_llm_events.py @@ -242,3 +242,79 @@ async def fake_acompletion(**kwargs): text = str(details) assert "THIS_SHOULD_NOT_LEAK" not in text assert "sk-THIS_SHOULD_NOT_LEAK" not in text + + +@pytest.mark.anyio +async def test_process_input_timeout_is_not_reported_as_cancellation(): + config = ConfigModel(model="test-model", api_key="test-key") + session = LLMSession(config=config, skill_manager=SkillManager()) + + events = [] + + def event_callback(event): + events.append(event) + return LLMCallbackResult.CONTINUE + + session.event_callback = event_callback + + async def fake_acompletion(**kwargs): + raise TimeoutError("request timed out") + + context_manager = ContextManager() + + with ( + patch.object(session, "_get_acompletion", return_value=fake_acompletion), + patch.object(session, "_trim_messages", side_effect=lambda msgs: msgs), + patch.object(session, "_get_tools_spec", return_value=[]), + ): + result = await session.process_input( + prompt="hi", + context_manager=context_manager, + system_message="sys", + ) + + assert result == "LLM request timed out" + event_types = [event.event_type for event in events] + assert event_types == [ + LLMEventType.OP_START, + LLMEventType.GENERATION_START, + LLMEventType.GENERATION_END, + LLMEventType.OP_END, + ] + assert events[2].data.get("status") == "timeout" + assert events[-1].data.get("cancelled") is False + assert events[-1].data.get("cancelled_reason") is None + + +@pytest.mark.anyio +async def test_completion_timeout_is_not_reported_as_cancellation(): + config = ConfigModel(model="test-model", api_key="test-key") + session = LLMSession(config=config, skill_manager=SkillManager()) + + events = [] + + def event_callback(event): + events.append(event) + return LLMCallbackResult.CONTINUE + + session.event_callback = event_callback + + async def fake_acompletion(**kwargs): + raise TimeoutError("request timed out") + + with patch.object(session, "_get_acompletion", return_value=fake_acompletion): + result = await session.completion( + prompt="hi", system_message="sys", stream=False + ) + + assert result == "LLM request timed out" + event_types = [event.event_type for event in events] + assert event_types == [ + LLMEventType.OP_START, + LLMEventType.GENERATION_START, + LLMEventType.GENERATION_END, + LLMEventType.OP_END, + ] + assert events[2].data.get("status") == "timeout" + assert events[-1].data.get("cancelled") is False + assert events[-1].data.get("cancelled_reason") is None diff --git a/tests/shell/runtime/test_shell_pty_core.py b/tests/shell/runtime/test_shell_pty_core.py index 3f3620f..97e7c39 100644 --- a/tests/shell/runtime/test_shell_pty_core.py +++ b/tests/shell/runtime/test_shell_pty_core.py @@ -3,6 +3,7 @@ from __future__ import annotations import os +import asyncio import threading from types import SimpleNamespace @@ -10,6 +11,7 @@ from unittest.mock import Mock from unittest.mock import call +from aish.state.cancellation import CancellationToken from aish.memory.config import MemoryConfig from aish.memory.models import MemoryCategory from aish.i18n import t @@ -176,6 +178,39 @@ def test_ai_handler_marks_cancelled_operation_and_notifies_shell(): shell.handle_processing_cancelled.assert_called_once_with() +@pytest.mark.timeout(5) +def test_run_async_in_thread_cancels_running_coroutine(): + token = CancellationToken() + started = threading.Event() + cleaned_up = threading.Event() + error_box: list[BaseException | None] = [None] + + async def _hang_until_cancelled(): + started.set() + try: + while True: + await asyncio.sleep(1) + finally: + cleaned_up.set() + + def _run() -> None: + try: + AIHandler._run_async_in_thread(_hang_until_cancelled(), token) + except BaseException as exc: + error_box[0] = exc + + thread = threading.Thread(target=_run) + thread.start() + + assert started.wait(timeout=1) + token.cancel() + thread.join(timeout=3) + + assert not thread.is_alive() + assert isinstance(error_box[0], KeyboardInterrupt) + assert cleaned_up.wait(timeout=1) + + def test_ai_handler_auto_retain_persists_explicit_fact(): handler, shell = _make_ai_handler() shell.memory_manager = Mock()