diff --git a/agent_core/__init__.py b/agent_core/__init__.py index ee6a3fd6..d0757090 100644 --- a/agent_core/__init__.py +++ b/agent_core/__init__.py @@ -75,6 +75,7 @@ get_credentials, has_embedded_credentials, run_oauth_flow, + run_oauth_flow_async, ) from agent_core.core.config import ( ConfigRegistry, @@ -312,6 +313,7 @@ "get_credentials", "has_embedded_credentials", "run_oauth_flow", + "run_oauth_flow_async", # Config "ConfigRegistry", "get_workspace_root", diff --git a/agent_core/core/credentials/__init__.py b/agent_core/core/credentials/__init__.py index 39200ffc..055a6c77 100644 --- a/agent_core/core/credentials/__init__.py +++ b/agent_core/core/credentials/__init__.py @@ -8,7 +8,7 @@ encode_credential, generate_credentials_block, ) -from agent_core.core.credentials.oauth_server import run_oauth_flow +from agent_core.core.credentials.oauth_server import run_oauth_flow, run_oauth_flow_async __all__ = [ "get_credential", @@ -17,4 +17,5 @@ "encode_credential", "generate_credentials_block", "run_oauth_flow", + "run_oauth_flow_async", ] diff --git a/agent_core/core/credentials/oauth_server.py b/agent_core/core/credentials/oauth_server.py index ac9f4770..9d8a701f 100644 --- a/agent_core/core/credentials/oauth_server.py +++ b/agent_core/core/credentials/oauth_server.py @@ -16,8 +16,12 @@ # HTTPS (for Slack and other providers requiring https redirect URIs) code, error = run_oauth_flow("https://slack.com/oauth/...", use_https=True) + + # Async version with cancellation support (recommended for UI contexts) + code, error = await run_oauth_flow_async("https://provider.com/oauth/...") """ +import asyncio import ipaddress import logging import os @@ -29,7 +33,7 @@ from datetime import datetime, timedelta, timezone from http.server import HTTPServer, BaseHTTPRequestHandler from urllib.parse import urlparse, parse_qs -from typing import Optional, Tuple +from typing import Any, Dict, Optional, Tuple logger = logging.getLogger(__name__) @@ -104,58 +108,78 @@ def _cleanup_files(*paths: str) -> None: pass -class _OAuthCallbackHandler(BaseHTTPRequestHandler): - """Handler for OAuth callback requests.""" - - code: Optional[str] = None - state: Optional[str] = None - error: Optional[str] = None - - def do_GET(self): - """Handle GET request from OAuth callback.""" - params = parse_qs(urlparse(self.path).query) - _OAuthCallbackHandler.code = params.get("code", [None])[0] - _OAuthCallbackHandler.state = params.get("state", [None])[0] - _OAuthCallbackHandler.error = params.get("error", [None])[0] +def _make_callback_handler(result_holder: Dict[str, Any]): + """ + Create a callback handler class that stores results in the provided dict. - self.send_response(200) - self.send_header("Content-Type", "text/html") - self.end_headers() - if _OAuthCallbackHandler.code: - self.wfile.write( - b"
You can close this tab.
" - ) - else: - self.wfile.write( - f"{_OAuthCallbackHandler.error}
".encode() - ) + This avoids class-level state that would be shared across OAuth flows. + """ + class _OAuthCallbackHandler(BaseHTTPRequestHandler): + """Handler for OAuth callback requests.""" + + def do_GET(self): + """Handle GET request from OAuth callback.""" + params = parse_qs(urlparse(self.path).query) + result_holder["code"] = params.get("code", [None])[0] + result_holder["state"] = params.get("state", [None])[0] + result_holder["error"] = params.get("error", [None])[0] + + self.send_response(200) + self.send_header("Content-Type", "text/html") + self.end_headers() + if result_holder["code"]: + self.wfile.write( + b"You can close this tab.
" + ) + else: + self.wfile.write( + f"{result_holder['error']}
".encode() + ) + + def log_message(self, format, *args): + """Suppress default HTTP server logging.""" + pass - def log_message(self, format, *args): - """Suppress default HTTP server logging.""" - pass + return _OAuthCallbackHandler -def _serve_until_code(server: HTTPServer, deadline: float) -> None: +def _serve_until_code( + server: HTTPServer, + deadline: float, + result_holder: Dict[str, Any], + cancel_event: Optional[threading.Event] = None, +) -> None: """ - Handle requests in a loop until we capture the OAuth code/error or timeout. + Handle requests in a loop until we capture the OAuth code/error, timeout, or cancelled. A single handle_request() can be consumed by TLS handshake failures, favicon requests, browser pre-connects, etc. Looping ensures the server stays alive for the actual callback. """ while time.time() < deadline: - remaining = max(0.5, deadline - time.time()) - server.timeout = min(remaining, 2.0) + # Check for cancellation + if cancel_event and cancel_event.is_set(): + logger.debug("[OAUTH] Cancellation requested, stopping server") + break + + remaining = max(0.1, deadline - time.time()) + # Use shorter timeout (0.5s) for responsive cancellation checking + server.timeout = min(remaining, 0.5) try: server.handle_request() except Exception as e: logger.debug(f"[OAUTH] handle_request error (will retry): {e}") - if _OAuthCallbackHandler.code or _OAuthCallbackHandler.error: + + if result_holder.get("code") or result_holder.get("error"): break def run_oauth_flow( - auth_url: str, port: int = 8765, timeout: int = 120, use_https: bool = False + auth_url: str, + port: int = 8765, + timeout: int = 120, + use_https: bool = False, + cancel_event: Optional[threading.Event] = None, ) -> Tuple[Optional[str], Optional[str]]: """ Open browser for OAuth, wait for callback. @@ -167,17 +191,27 @@ def run_oauth_flow( use_https: If True, serve HTTPS with a self-signed cert. Required for providers like Slack that reject http:// redirect URIs. Default False (plain HTTP — works with Google, Notion, etc.). + cancel_event: Optional threading.Event to signal cancellation. + When set, the OAuth flow will stop and return a cancellation error. Returns: Tuple of (code, error_message): - On success: (authorization_code, None) - On failure: (None, error_message) """ - _OAuthCallbackHandler.code = None - _OAuthCallbackHandler.state = None - _OAuthCallbackHandler.error = None + # Check for early cancellation + if cancel_event and cancel_event.is_set(): + return None, "OAuth cancelled" - server = HTTPServer(("127.0.0.1", port), _OAuthCallbackHandler) + # Use instance-level result holder instead of class-level state + result_holder: Dict[str, Any] = {"code": None, "state": None, "error": None} + handler_class = _make_callback_handler(result_holder) + + try: + server = HTTPServer(("127.0.0.1", port), handler_class) + except OSError as e: + # Port already in use + return None, f"Failed to start OAuth server: {e}" if use_https: cert_path = key_path = None @@ -198,21 +232,85 @@ def run_oauth_flow( deadline = time.time() + timeout thread = threading.Thread( - target=_serve_until_code, args=(server, deadline), daemon=True + target=_serve_until_code, + args=(server, deadline, result_holder, cancel_event), + daemon=True ) thread.start() + # Check cancellation before opening browser + if cancel_event and cancel_event.is_set(): + server.server_close() + return None, "OAuth cancelled" + try: webbrowser.open(auth_url) except Exception: server.server_close() return None, f"Could not open browser. Visit manually:\n{auth_url}" - thread.join(timeout=timeout) + # Wait for thread with periodic cancellation checks + while thread.is_alive(): + thread.join(timeout=0.5) + if cancel_event and cancel_event.is_set(): + logger.debug("[OAUTH] Cancellation detected during wait") + break + server.server_close() - if _OAuthCallbackHandler.error: - return None, _OAuthCallbackHandler.error - if _OAuthCallbackHandler.code: - return _OAuthCallbackHandler.code, None + # Check cancellation first + if cancel_event and cancel_event.is_set(): + return None, "OAuth cancelled" + + if result_holder.get("error"): + return None, result_holder["error"] + if result_holder.get("code"): + return result_holder["code"], None return None, "OAuth timed out." + + +async def run_oauth_flow_async( + auth_url: str, + port: int = 8765, + timeout: int = 120, + use_https: bool = False, +) -> Tuple[Optional[str], Optional[str]]: + """ + Async version of run_oauth_flow with proper cancellation support. + + This function runs the OAuth flow in a thread executor and properly handles + asyncio task cancellation by signaling the OAuth server to stop. + + Args: + auth_url: The full OAuth authorization URL to open. + port: Local port for callback server (default: 8765). + timeout: Seconds to wait for callback (default: 120). + use_https: If True, serve HTTPS with a self-signed cert. + + Returns: + Tuple of (code, error_message): + - On success: (authorization_code, None) + - On failure: (None, error_message) + + Raises: + asyncio.CancelledError: If the task is cancelled (after signaling OAuth to stop) + """ + cancel_event = threading.Event() + loop = asyncio.get_event_loop() + + def run_flow(): + return run_oauth_flow( + auth_url=auth_url, + port=port, + timeout=timeout, + use_https=use_https, + cancel_event=cancel_event, + ) + + try: + return await loop.run_in_executor(None, run_flow) + except asyncio.CancelledError: + # Signal the OAuth server to stop + cancel_event.set() + logger.debug("[OAUTH] Async task cancelled, signaled OAuth server to stop") + raise diff --git a/app/credentials/handlers.py b/app/credentials/handlers.py index c8eceb31..1924910f 100644 --- a/app/credentials/handlers.py +++ b/app/credentials/handlers.py @@ -76,8 +76,8 @@ async def login(self, args): "code_challenge": code_challenge, "code_challenge_method": "S256", } - from agent_core import run_oauth_flow - code, error = run_oauth_flow(f"https://accounts.google.com/o/oauth2/v2/auth?{urlencode(params)}") + from agent_core import run_oauth_flow_async + code, error = await run_oauth_flow_async(f"https://accounts.google.com/o/oauth2/v2/auth?{urlencode(params)}") if error: return False, f"Google OAuth failed: {error}" token_data = { @@ -141,8 +141,8 @@ async def invite(self, args): scopes = "chat:write,channels:read,channels:history,groups:read,groups:history,users:read,files:write,im:read,im:write,im:history" params = {"client_id": SLACK_SHARED_CLIENT_ID, "scope": scopes, "redirect_uri": REDIRECT_URI_HTTPS, "state": secrets.token_urlsafe(32)} - from agent_core import run_oauth_flow - code, error = run_oauth_flow(f"https://slack.com/oauth/v2/authorize?{urlencode(params)}", use_https=True) + from agent_core import run_oauth_flow_async + code, error = await run_oauth_flow_async(f"https://slack.com/oauth/v2/authorize?{urlencode(params)}", use_https=True) if error: return False, f"Slack OAuth failed: {error}" import aiohttp @@ -206,8 +206,8 @@ async def invite(self, args): return False, "CraftOS Notion integration not configured. Set NOTION_SHARED_CLIENT_ID and NOTION_SHARED_CLIENT_SECRET env vars.\nAlternatively, use /notion login