Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions src/aish/llm/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -1774,8 +1774,14 @@ async def process_input(
else:
msg = response["choices"][0]["message"] # type: ignore
finish_reason = response["choices"][0]["finish_reason"] # type: ignore
except TimeoutError:
events.emit_cancelled("llm_timeout")
except TimeoutError as err:
if (
self.cancellation_token
and self.cancellation_token.is_cancelled()
):
events.emit_cancelled("llm_cancelled")
events.emit_generation_end(status="cancelled")
raise anyio.get_cancelled_exc_class() from err
events.emit_generation_end(status="timeout")
output = "LLM request timed out"
break
Expand Down Expand Up @@ -2010,9 +2016,12 @@ async def completion(
events.emit_cancelled("llm_cancelled")
events.emit_generation_end(status="cancelled")
raise
except TimeoutError:
except TimeoutError as err:
if self.cancellation_token and self.cancellation_token.is_cancelled():
events.emit_cancelled("llm_cancelled")
events.emit_generation_end(status="cancelled")
raise anyio.get_cancelled_exc_class() from err
result = "LLM request timed out"
events.emit_cancelled("llm_timeout")
events.emit_generation_end(status="timeout")
except Exception as e:
if isinstance(e, Exception) and is_litellm_exception(e):
Expand Down
30 changes: 28 additions & 2 deletions src/aish/shell/runtime/ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,9 @@ def _shutdown_loop(loop: asyncio.AbstractEventLoop) -> None:
def _run_async_in_thread(coro, cancellation_token=None) -> Any:
"""Run an async coroutine in a separate thread with its own event loop.

Uses polling-based cancellation to allow Ctrl+C interruption.
Uses polling-based cancellation to allow Ctrl+C interruption, and
forwards cancellation into the event loop so in-flight HTTP requests
do not keep running in the background until their timeout expires.
"""
from concurrent.futures import (
ThreadPoolExecutor,
Expand All @@ -139,27 +141,51 @@ def _run_async_in_thread(coro, cancellation_token=None) -> Any:

result_box: list[Optional[str]] = [None]
exc_box: list[BaseException | None] = [None]
loop_ready = threading.Event()
loop_box: list[asyncio.AbstractEventLoop | None] = [None]
task_box: list[asyncio.Task | None] = [None]

def run_in_thread() -> None:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop_box[0] = loop
try:
result_box[0] = loop.run_until_complete(coro)
task = loop.create_task(coro)
task_box[0] = task
loop_ready.set()
result_box[0] = loop.run_until_complete(task)
except BaseException as e:
exc_box[0] = e
finally:
loop_ready.set()
AIHandler._shutdown_loop(loop)
loop.close()
loop_box[0] = None
task_box[0] = None

pool = ThreadPoolExecutor(max_workers=1)
future = pool.submit(run_in_thread)

def _cancel_running_task() -> None:
loop = loop_box[0]
task = task_box[0]
if loop is None or task is None or task.done() or loop.is_closed():
return
loop.call_soon_threadsafe(task.cancel)

Comment on lines +169 to +175
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Ensure cancellation is actually delivered before returning cancellation to caller.

At lines 174 and 186–189, `_run_async_in_thread` can raise `KeyboardInterrupt` even if task cancellation was never delivered (loop not ready, or future still running), and `call_soon_threadsafe` can race with the loop closing and raise `RuntimeError`. Either case can leave the coroutine running in the background after the caller believes it was cancelled.

Suggested hardening
-        def _cancel_running_task() -> None:
+        def _cancel_running_task() -> bool:
             loop = loop_box[0]
             task = task_box[0]
             if loop is None or task is None or task.done() or loop.is_closed():
-                return
-            loop.call_soon_threadsafe(task.cancel)
+                return False
+            try:
+                loop.call_soon_threadsafe(task.cancel)
+                return True
+            except RuntimeError:
+                # loop closed between check and scheduling
+                return False
...
                     if cancellation_token and cancellation_token.is_cancelled():
-                        if loop_ready.wait(timeout=0.2):
-                            _cancel_running_task()
+                        cancel_sent = False
+                        if loop_ready.wait(timeout=0.2):
+                            cancel_sent = _cancel_running_task()
                         try:
                             future.result(timeout=2.0)
                         except FutureTimeoutError:
-                            pass
+                            if cancel_sent and not future.done():
+                                continue
                         raise KeyboardInterrupt("AI operation cancelled by user")

Also applies to: 183-189

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/aish/shell/runtime/ai.py` around lines 169 - 175, The
_cancel_running_task() helper currently calls
loop.call_soon_threadsafe(task.cancel) but may return before cancellation is
delivered and can raise RuntimeError if the loop closes; update
_cancel_running_task() and the _run_async_in_thread cancellation path so you:
(1) guard call_soon_threadsafe with loop.is_running() and catch RuntimeError,
(2) use asyncio.run_coroutine_threadsafe or schedule a small coroutine on the
target loop that cancels the task and awaits task completion, and (3) block the
caller (with a bounded timeout) until task.done() or the awaitable from
run_coroutine_threadsafe finishes, then only return/raise KeyboardInterrupt
after confirming task.cancelled() or task.done(); apply this logic to the
cancellation flows in _cancel_running_task() and the cancellation handling
inside _run_async_in_thread to avoid leaving the coroutine running in
background.

try:
while not future.done():
try:
future.result(timeout=0.2)
except FutureTimeoutError:
# Check if cancellation was requested
if cancellation_token and cancellation_token.is_cancelled():
if loop_ready.wait(timeout=0.2):
_cancel_running_task()
try:
future.result(timeout=2.0)
except FutureTimeoutError:
pass
raise KeyboardInterrupt("AI operation cancelled by user")
finally:
pool.shutdown(wait=False)
Expand Down
76 changes: 76 additions & 0 deletions tests/llm/test_llm_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,3 +242,79 @@ async def fake_acompletion(**kwargs):
text = str(details)
assert "THIS_SHOULD_NOT_LEAK" not in text
assert "sk-THIS_SHOULD_NOT_LEAK" not in text


@pytest.mark.anyio
async def test_process_input_timeout_is_not_reported_as_cancellation():
    """Check that a TimeoutError in process_input is reported as a timeout.

    The completion callable is patched to raise ``TimeoutError``. The call
    must return the timeout message, emit exactly the four expected events,
    mark the generation-end event with status ``"timeout"``, and leave the
    final event without any cancellation flag or reason.
    """
    session = LLMSession(
        config=ConfigModel(model="test-model", api_key="test-key"),
        skill_manager=SkillManager(),
    )

    captured = []

    def _record(event):
        # Collect every emitted event and let the session carry on.
        captured.append(event)
        return LLMCallbackResult.CONTINUE

    session.event_callback = _record

    async def _raise_timeout(**kwargs):
        raise TimeoutError("request timed out")

    context_manager = ContextManager()

    with (
        patch.object(session, "_get_acompletion", return_value=_raise_timeout),
        patch.object(session, "_trim_messages", side_effect=lambda msgs: msgs),
        patch.object(session, "_get_tools_spec", return_value=[]),
    ):
        result = await session.process_input(
            prompt="hi",
            context_manager=context_manager,
            system_message="sys",
        )

    assert result == "LLM request timed out"

    expected_sequence = [
        LLMEventType.OP_START,
        LLMEventType.GENERATION_START,
        LLMEventType.GENERATION_END,
        LLMEventType.OP_END,
    ]
    assert [item.event_type for item in captured] == expected_sequence

    generation_end = captured[2]
    final_event = captured[-1]
    assert generation_end.data.get("status") == "timeout"
    assert final_event.data.get("cancelled") is False
    assert final_event.data.get("cancelled_reason") is None


@pytest.mark.anyio
async def test_completion_timeout_is_not_reported_as_cancellation():
    """Check that a TimeoutError in completion is reported as a timeout.

    The completion callable is patched to raise ``TimeoutError``. The
    non-streaming call must return the timeout message, emit exactly the
    four expected events, mark the generation-end event with status
    ``"timeout"``, and leave the final event without any cancellation flag
    or reason.
    """
    session = LLMSession(
        config=ConfigModel(model="test-model", api_key="test-key"),
        skill_manager=SkillManager(),
    )

    captured = []

    def _record(event):
        # Collect every emitted event and let the session carry on.
        captured.append(event)
        return LLMCallbackResult.CONTINUE

    session.event_callback = _record

    async def _raise_timeout(**kwargs):
        raise TimeoutError("request timed out")

    with patch.object(session, "_get_acompletion", return_value=_raise_timeout):
        result = await session.completion(
            prompt="hi", system_message="sys", stream=False
        )

    assert result == "LLM request timed out"

    expected_sequence = [
        LLMEventType.OP_START,
        LLMEventType.GENERATION_START,
        LLMEventType.GENERATION_END,
        LLMEventType.OP_END,
    ]
    assert [item.event_type for item in captured] == expected_sequence

    generation_end = captured[2]
    final_event = captured[-1]
    assert generation_end.data.get("status") == "timeout"
    assert final_event.data.get("cancelled") is False
    assert final_event.data.get("cancelled_reason") is None
35 changes: 35 additions & 0 deletions tests/shell/runtime/test_shell_pty_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@
from __future__ import annotations

import os
import asyncio
import threading
from types import SimpleNamespace

import pytest
from unittest.mock import Mock
from unittest.mock import call

from aish.state.cancellation import CancellationToken
from aish.memory.config import MemoryConfig
from aish.memory.models import MemoryCategory
from aish.i18n import t
Expand Down Expand Up @@ -176,6 +178,39 @@ def test_ai_handler_marks_cancelled_operation_and_notifies_shell():
shell.handle_processing_cancelled.assert_called_once_with()


@pytest.mark.timeout(5)
def test_run_async_in_thread_cancels_running_coroutine():
    """Check that cancelling the token stops a coroutine that never finishes.

    A coroutine that sleeps in an endless loop is handed to
    ``AIHandler._run_async_in_thread`` on a helper thread. After the token
    is cancelled, the helper thread must finish, the call must have raised
    ``KeyboardInterrupt``, and the coroutine's ``finally`` clause must have
    run.
    """
    token = CancellationToken()
    coro_started = threading.Event()
    coro_finalized = threading.Event()
    raised: list[BaseException | None] = [None]

    async def _sleep_forever():
        coro_started.set()
        try:
            while True:
                await asyncio.sleep(1)
        finally:
            # Set on any exit path, so the test can see that cleanup ran.
            coro_finalized.set()

    def _invoke() -> None:
        # Capture whatever the call raises so the main thread can inspect it.
        try:
            AIHandler._run_async_in_thread(_sleep_forever(), token)
        except BaseException as exc:
            raised[0] = exc

    worker = threading.Thread(target=_invoke)
    worker.start()

    assert coro_started.wait(timeout=1)
    token.cancel()
    worker.join(timeout=3)

    assert not worker.is_alive()
    assert isinstance(raised[0], KeyboardInterrupt)
    assert coro_finalized.wait(timeout=1)


def test_ai_handler_auto_retain_persists_explicit_fact():
handler, shell = _make_ai_handler()
shell.memory_manager = Mock()
Expand Down
Loading