From 380af20a203cfe25dfb5b8fd8db155ed87b59cc3 Mon Sep 17 00:00:00 2001 From: MervinPraison Date: Mon, 20 Apr 2026 23:06:50 +0100 Subject: [PATCH] feat: OAuth providers, header auth, JWT sessions, thread sharing (closes #22) Auth additions in src/praisonaiui/auth.py: * User / Session dataclasses * OAuthCallbackProtocol, HeaderAuthCallbackProtocol, PasswordAuthCallbackProtocol, OnLogoutCallbackProtocol * @oauth_callback, @header_auth_callback, @password_auth_callback, @on_logout decorators * authenticate_with_headers(), authenticate_with_password(), handle_logout_with_callback(), enhanced_login_handler, enhanced_logout_handler * _user_tokens mapping for O(1) token revocation New modules: * src/praisonaiui/oauth_providers.py - GitHub/Google/Azure/Okta OAuth2 clients + state cleanup background task * src/praisonaiui/features/sharing.py - thread share-token creation, revocation, access checks + @on_shared_thread_view HTTP routes added in server.py: * GET /api/auth/oauth/{provider} * GET /api/auth/oauth/{provider}/callback * POST /api/threads/{thread_id}/share * POST /api/threads/{thread_id}/unshare * GET /shared/{token} Frontend: * src/frontend/src/chat/ShareThreadButton.tsx * src/frontend/src/pages/SharedThread.tsx Public exports (praisonaiui.__init__): oauth_callback, header_auth_callback, password_auth_callback, on_logout, on_shared_thread_view, User, Session Tests: 33 new (14 OAuth, 19 sharing). Full suite: 826 pass, 1 skipped, 7 xfailed (pre-existing PR#30 mock bugs). --- src/frontend/src/chat/ShareThreadButton.tsx | 206 +++++++++ src/frontend/src/pages/SharedThread.tsx | 160 +++++++ src/praisonaiui/__init__.py | 25 ++ src/praisonaiui/auth.py | 303 ++++++++++++- src/praisonaiui/features/sharing.py | 161 +++++++ src/praisonaiui/oauth_providers.py | 451 ++++++++++++++++++++ src/praisonaiui/server.py | 252 ++++++++++- tests/unit/test_oauth.py | 339 +++++++++++++++ tests/unit/test_sharing.py | 409 ++++++++++++++++++ 9 files changed, 2304 insertions(+), 2 deletions(-) create mode 100644 src/frontend/src/chat/ShareThreadButton.tsx create mode 100644 src/frontend/src/pages/SharedThread.tsx create mode 100644 src/praisonaiui/features/sharing.py create mode 100644 src/praisonaiui/oauth_providers.py create mode 100644 tests/unit/test_oauth.py create mode 100644 tests/unit/test_sharing.py diff --git a/src/frontend/src/chat/ShareThreadButton.tsx b/src/frontend/src/chat/ShareThreadButton.tsx new file mode 100644 index 0000000..5ef3024 --- /dev/null +++ b/src/frontend/src/chat/ShareThreadButton.tsx @@ -0,0 +1,206 @@ +import { useState } from 'react'; +import { Share2, Copy, Check, X } from 'lucide-react'; +import { Button } from '../components/ui/button'; +import { + Dialog, + DialogContent, + DialogDescription, + DialogHeader, + DialogTitle, + DialogTrigger, +} from '../components/ui/dialog'; +import { Input } from '../components/ui/input'; +import { Label } from '../components/ui/label'; + +interface ShareThreadButtonProps { + threadId: string; + className?: string; +} + +export function ShareThreadButton({ threadId, className }: ShareThreadButtonProps) { + const [isOpen, setIsOpen] = useState(false); + const [shareUrl, setShareUrl] = useState(''); + const [isLoading, setIsLoading] = useState(false); + const [isCopied, setIsCopied] = useState(false); + const [error, setError] = useState(''); + + const createShareLink = async () => { + setIsLoading(true); + setError(''); + + try { + const token = localStorage.getItem('token'); + if (!token) { + setError('Authentication required'); + return; + } + + const response = await fetch(`/api/threads/${threadId}/share`, { + method: 'POST', + headers: { + 'Authorization': `Bearer ${token}`, + 'Content-Type': 'application/json', + }, + }); + + if (!response.ok) { + const errorData = await response.json(); + throw new Error(errorData.error || 'Failed to create share link'); + } + + const data = await response.json(); + setShareUrl(data.url); + } catch (err) { + setError(err instanceof Error ? err.message : 'Failed to create share link'); + } finally { + setIsLoading(false); + } + }; + + const revokeShareLink = async () => { + setIsLoading(true); + setError(''); + + try { + const token = localStorage.getItem('token'); + if (!token) { + setError('Authentication required'); + return; + } + + const response = await fetch(`/api/threads/${threadId}/unshare`, { + method: 'POST', + headers: { + 'Authorization': `Bearer ${token}`, + 'Content-Type': 'application/json', + }, + }); + + if (!response.ok) { + const errorData = await response.json(); + throw new Error(errorData.error || 'Failed to revoke share link'); + } + + setShareUrl(''); + setIsCopied(false); + } catch (err) { + setError(err instanceof Error ? err.message : 'Failed to revoke share link'); + } finally { + setIsLoading(false); + } + }; + + const copyToClipboard = async () => { + try { + await navigator.clipboard.writeText(shareUrl); + setIsCopied(true); + setTimeout(() => setIsCopied(false), 2000); + } catch (err) { + console.error('Failed to copy to clipboard:', err); + } + }; + + const handleOpenChange = (open: boolean) => { + setIsOpen(open); + if (open && !shareUrl) { + createShareLink(); + } + }; + + return ( + + + + + + + Share Thread + + Anyone with this link will be able to view this conversation in read-only mode. + + + +
+ {error && ( +
+ {error} +
+ )} + + {shareUrl ? ( +
+
+ +
+ + +
+
+ +
+ + +
+ +
+
+
+ ) : ( +
+ {isLoading ? ( +
Creating share link...
+ ) : ( + + )} +
+ )} +
+
+
+ ); +} \ No newline at end of file diff --git a/src/frontend/src/pages/SharedThread.tsx b/src/frontend/src/pages/SharedThread.tsx new file mode 100644 index 0000000..7f2d596 --- /dev/null +++ b/src/frontend/src/pages/SharedThread.tsx @@ -0,0 +1,160 @@ +import { useEffect, useState } from 'react'; +import { useParams } from 'react-router-dom'; +import { Eye, AlertCircle } from 'lucide-react'; +import { ChatMessages } from '../chat/ChatMessages'; +import { Button } from '../components/ui/button'; + +interface SharedThreadData { + thread_id: string; + session: { + id: string; + title: string; + created_at: string; + }; + messages: Array<{ + id: string; + role: string; + content: string; + timestamp: string; + toolCalls?: any[]; + }>; + read_only: boolean; +} + +export function SharedThread() { + const { token } = useParams<{ token: string }>(); + const [threadData, setThreadData] = useState(null); + const [isLoading, setIsLoading] = useState(true); + const [error, setError] = useState(''); + + useEffect(() => { + const fetchSharedThread = async () => { + if (!token) { + setError('Invalid share link'); + setIsLoading(false); + return; + } + + try { + const response = await fetch(`/shared/${token}`, { + headers: { + 'Accept': 'application/json', + }, + }); + + if (!response.ok) { + const errorData = await response.json(); + throw new Error(errorData.error || 'Failed to load shared thread'); + } + + const data = await response.json(); + setThreadData(data); + } catch (err) { + setError(err instanceof Error ? err.message : 'Failed to load shared thread'); + } finally { + setIsLoading(false); + } + }; + + fetchSharedThread(); + }, [token]); + + if (isLoading) { + return ( +
+
+
+
Loading shared conversation...
+
+
+ ); + } + + if (error) { + return ( +
+
+ +

+ Unable to Load Shared Thread +

+

{error}

+ +
+
+ ); + } + + if (!threadData) { + return null; + } + + return ( +
+ {/* Header */} +
+
+
+
+ +
+

+ {threadData.session.title || 'Shared Conversation'} +

+

+ Read-only view • Created {new Date(threadData.session.created_at).toLocaleDateString()} +

+
+
+ + +
+
+
+ + {/* Messages */} +
+
+ {threadData.messages.length > 0 ? ( +
+ +
+ ) : ( +
+
+ +

This conversation is empty.

+
+
+ )} +
+ + {/* Read-only notice */} +
+
+ + + You're viewing a read-only shared conversation. + You can't reply or interact with this chat. + +
+
+
+
+ ); +} \ No newline at end of file diff --git a/src/praisonaiui/__init__.py b/src/praisonaiui/__init__.py index 26047a3..c300d6a 100644 --- a/src/praisonaiui/__init__.py +++ b/src/praisonaiui/__init__.py @@ -75,6 +75,15 @@ def __getattr__(name: str): "current_user", "on_slack_reaction_added", } + _auth_attrs = { + "oauth_callback", + "header_auth_callback", + "password_auth_callback", + "on_logout", + "User", + "Session", + "on_shared_thread_view", + } _usage_attrs = {"get_token_usage"} _instrumentation_attrs = { "instrument_openai", @@ -227,6 +236,14 @@ def __getattr__(name: str): from praisonaiui.features import platform_adapters return getattr(platform_adapters, name) + if name in _auth_attrs: + if name == "on_shared_thread_view": + from praisonaiui.features import sharing + + return getattr(sharing, name) + from praisonaiui import auth + + return getattr(auth, name) if name in _usage_attrs: from praisonaiui.features import usage @@ -380,6 +397,14 @@ def __getattr__(name: str): "current_channel", "current_user", "on_slack_reaction_added", + # Auth decorators & classes (OAuth, header, password, logout, sharing) + "oauth_callback", + "header_auth_callback", + "password_auth_callback", + "on_logout", + "on_shared_thread_view", + "User", + "Session", # LLM instrumentation "instrument_openai", "instrument_anthropic", diff --git a/src/praisonaiui/auth.py b/src/praisonaiui/auth.py index 19fc19d..7d44ec3 100644 --- a/src/praisonaiui/auth.py +++ b/src/praisonaiui/auth.py @@ -4,7 +4,7 @@ import secrets from datetime import datetime, timedelta -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, Protocol from starlette.middleware.base import BaseHTTPMiddleware from starlette.requests import Request @@ -225,3 +225,304 @@ async def me_handler(request: Request) -> JSONResponse: } ) return JSONResponse({"error": "Unauthorized"}, status_code=401) + + +# ── OAuth / header / session extensions (issue #22) ── + +_oauth_callbacks: dict[str, OAuthCallbackProtocol] = {} +_header_auth_callback: Optional[HeaderAuthCallbackProtocol] = None +_password_auth_callback: Optional[PasswordAuthCallbackProtocol] = None +_on_logout_callback: Optional[OnLogoutCallbackProtocol] = None +_user_tokens: dict[str, set[str]] = {} + + +class User: + """User data class for authentication callbacks.""" + + def __init__( + self, + identifier: str, + display_name: str = "", + metadata: Optional[dict[str, Any]] = None, + ): + self.identifier = identifier + self.display_name = display_name or identifier + self.metadata = metadata or {} + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary representation.""" + return { + "identifier": self.identifier, + "display_name": self.display_name, + "metadata": self.metadata, + } + + +class Session: + """Session data class for auth callbacks.""" + + def __init__(self, session_id: str, user_id: str, tokens: dict[str, Any]): + self.session_id = session_id + self.user_id = user_id + self.tokens = tokens + + async def clear_tokens(self) -> None: + """Clear all tokens for this session.""" + # Clear JWT tokens from global store efficiently using user mapping + if self.user_id in _user_tokens: + for token in list(_user_tokens[self.user_id]): + _tokens.pop(token, None) + del _user_tokens[self.user_id] + + +class OAuthCallbackProtocol(Protocol): + """Protocol for OAuth callback handlers.""" + + async def __call__( + self, + provider: str, + token: dict[str, Any], + raw_user: dict[str, Any], + default_user: User, + ) -> Optional[User]: + """Handle OAuth callback. + + Args: + provider: OAuth provider name (e.g., "github", "google") + token: OAuth token data from provider + raw_user: Raw user data from provider API + default_user: Default User instance based on OAuth data + + Returns: + User instance to accept login, None to deny + """ + ... + + +class HeaderAuthCallbackProtocol(Protocol): + """Protocol for header-based authentication.""" + + async def __call__(self, headers: dict[str, str]) -> Optional[User]: + """Authenticate user from request headers. + + Args: + headers: Request headers dict + + Returns: + User instance if authenticated, None if not + """ + ... + + +class PasswordAuthCallbackProtocol(Protocol): + """Protocol for password authentication callback.""" + + async def __call__(self, username: str, password: str) -> Optional[User]: + """Handle password authentication. + + Args: + username: Username + password: Password + + Returns: + User instance if authenticated, None if not + """ + ... + + +class OnLogoutCallbackProtocol(Protocol): + """Protocol for logout callback.""" + + async def __call__(self, user: User, session: Session) -> None: + """Handle user logout. + + Args: + user: User instance + session: Session instance + """ + ... + + +# Registry for auth callbacks (lazy-loaded) +_oauth_callbacks: dict[str, OAuthCallbackProtocol] = {} +_header_auth_callback: Optional[HeaderAuthCallbackProtocol] = None +_password_auth_callback: Optional[PasswordAuthCallbackProtocol] = None +_on_logout_callback: Optional[OnLogoutCallbackProtocol] = None + +# Legacy storage (backward compatibility) +_users: dict[str, dict[str, Any]] = {} +_tokens: dict[str, dict[str, Any]] = {} +# User to tokens mapping for O(1) token revocation +_user_tokens: dict[str, set[str]] = {} +_login_callback: Optional[Callable] = None + +TOKEN_EXPIRY_HOURS = 24 + +# Try to import bcrypt, fall back to hashlib if not available +try: + import bcrypt + + _HAS_BCRYPT = True +except ImportError: + import hashlib + + _HAS_BCRYPT = False + + +def oauth_callback(provider: str): + """Register an OAuth callback for a specific provider. + + Args: + provider: OAuth provider name (e.g., "github", "google", "azure", "okta") + + Example: + @aiui.oauth_callback("github") + async def on_github_login(provider, token, raw_user, default_user): + return aiui.User( + identifier=f"github:{raw_user['login']}", + display_name=raw_user["name"], + metadata={"avatar": raw_user["avatar_url"]}, + ) + """ + + def decorator(func: OAuthCallbackProtocol) -> OAuthCallbackProtocol: + _oauth_callbacks[provider] = func + return func + + return decorator + + +def header_auth_callback(func: HeaderAuthCallbackProtocol) -> HeaderAuthCallbackProtocol: + """Register a header-based authentication callback. + + Example: + @aiui.header_auth_callback + async def on_header_auth(headers): + email = headers.get("x-auth-request-email") + if email: + return aiui.User(identifier=email, display_name=email.split("@")[0]) + return None + """ + global _header_auth_callback + _header_auth_callback = func + return func + + +def password_auth_callback(func: PasswordAuthCallbackProtocol) -> PasswordAuthCallbackProtocol: + """Register an explicit password authentication callback. + + Example: + @aiui.password_auth_callback + async def on_password_auth(username, password): + # Custom password validation logic + if validate_ldap(username, password): + return aiui.User(identifier=username, display_name=username) + return None + """ + global _password_auth_callback + _password_auth_callback = func + return func + + +def on_logout(func: OnLogoutCallbackProtocol) -> OnLogoutCallbackProtocol: + """Register a logout callback for server-side cleanup. + + Example: + @aiui.on_logout + async def cleanup(user, session): + await session.clear_tokens() + logger.info("user %s logged out", user.identifier) + """ + global _on_logout_callback + _on_logout_callback = func + return func + + +# ── Enhanced Authentication Logic ────────────────────────────────────── + + +async def authenticate_with_headers(request: Request) -> Optional[User]: + """Authenticate using header auth callback if registered.""" + if not _header_auth_callback: + return None + + # Convert Starlette headers to dict + headers = {} + for name, value in request.headers.items(): + headers[name.lower()] = value + + try: + return await _header_auth_callback(headers) + except Exception: + return None + + +async def authenticate_with_password(username: str, password: str) -> Optional[User]: + """Enhanced password authentication with custom callback support.""" + # Try custom password callback first + if _password_auth_callback: + try: + return await _password_auth_callback(username, password) + except Exception: + pass + + # Fallback to legacy authentication + if username not in _users: + return None + user_data = _users[username] + if not verify_password(password, user_data["password_hash"]): + return None + + return User( + identifier=user_data["id"], + display_name=user_data["username"], + ) + + +async def handle_logout_with_callback(user: User, session: Session) -> None: + """Handle logout with optional callback.""" + if _on_logout_callback: + try: + await _on_logout_callback(user, session) + except Exception: + pass # Don't fail logout on callback errors + + +# ── Enhanced Auth Handlers ───────────────────────────────────────────── + + +async def enhanced_login_handler(request: Request) -> JSONResponse: + """Enhanced login handler with header auth support.""" + # Try header authentication first + user = await authenticate_with_headers(request) + if user: + token = create_token(user.identifier) + return JSONResponse( + { + "user": user.to_dict(), + "token": token, + } + ) + + # Fall back to password authentication + return await login_handler(request) + + +async def enhanced_logout_handler(request: Request) -> JSONResponse: + """Enhanced logout handler with callback support.""" + auth_header = request.headers.get("Authorization", "") + if auth_header.startswith("Bearer "): + token = auth_header[7:] + user_id = validate_token(token) + if user_id: + # Create User and Session instances for callback + user = User(identifier=user_id, display_name=user_id) + session = Session(session_id="", user_id=user_id, tokens={token: True}) + + # Call logout callback + await handle_logout_with_callback(user, session) + + # Revoke token + revoke_token(token) + + return JSONResponse({"status": "logged_out"}) diff --git a/src/praisonaiui/features/sharing.py b/src/praisonaiui/features/sharing.py new file mode 100644 index 0000000..094bc6b --- /dev/null +++ b/src/praisonaiui/features/sharing.py @@ -0,0 +1,161 @@ +"""Thread sharing functionality for PraisonAIUI.""" + +from __future__ import annotations + +import secrets +from datetime import datetime, timezone +from typing import Any, Optional, Protocol + +from praisonaiui.auth import User + + +class OnSharedThreadViewProtocol(Protocol): + """Protocol for shared thread view callback.""" + + async def __call__(self, thread_id: str, viewer: Optional[User]) -> bool: + """Check if thread can be viewed by the visitor. + + Args: + thread_id: Thread/session ID being accessed + viewer: User instance if authenticated, None if anonymous + + Returns: + True to allow access, False to deny + """ + ... + + +# Global registry for sharing callback (lazy-loaded) +_on_shared_thread_view_callback: Optional[OnSharedThreadViewProtocol] = None + +# In-memory storage for share tokens (should be replaced with database in production) +_share_tokens: dict[str, dict[str, Any]] = {} + + +def on_shared_thread_view(func: OnSharedThreadViewProtocol) -> OnSharedThreadViewProtocol: + """Register a callback for shared thread view access control. + + Example: + @aiui.on_shared_thread_view + async def on_share_view(thread_id, viewer): + # Return True to allow view, False to deny + # Receives None for viewer if anonymous + return True # Or check your ACL + """ + global _on_shared_thread_view_callback + _on_shared_thread_view_callback = func + return func + + +def create_share_token(thread_id: str, created_by: str) -> str: + """Create a share token for a thread. + + Args: + thread_id: Thread/session ID to share + created_by: User ID who created the share token + + Returns: + Opaque share token (32-byte URL-safe) + """ + # Generate collision-resistant token + token = secrets.token_urlsafe(32) + + _share_tokens[token] = { + "thread_id": thread_id, + "created_by": created_by, + "created_at": datetime.now(timezone.utc).isoformat(), + } + + return token + + +def get_thread_by_share_token(token: str) -> Optional[str]: + """Get thread ID from share token. + + Args: + token: Share token + + Returns: + Thread ID if token exists, None otherwise + """ + token_data = _share_tokens.get(token) + return token_data["thread_id"] if token_data else None + + +def revoke_share_token(thread_id: str, user_id: str) -> bool: + """Revoke share token(s) for a thread. + + Args: + thread_id: Thread ID + user_id: User who wants to revoke (must be creator) + + Returns: + True if token(s) were revoked, False if none found or not authorized + """ + revoked = False + tokens_to_remove = [] + + for token, data in _share_tokens.items(): + if data["thread_id"] == thread_id and data["created_by"] == user_id: + tokens_to_remove.append(token) + + for token in tokens_to_remove: + del _share_tokens[token] + revoked = True + + return revoked + + +def list_share_tokens(user_id: str) -> list[dict[str, Any]]: + """List share tokens created by a user. + + Args: + user_id: User ID + + Returns: + List of share token data + """ + return [ + { + "token": token, + "thread_id": data["thread_id"], + "created_at": data["created_at"], + } + for token, data in _share_tokens.items() + if data["created_by"] == user_id + ] + + +async def check_shared_thread_access(thread_id: str, viewer: Optional[User] = None) -> bool: + """Check if a thread can be viewed via sharing. + + Args: + thread_id: Thread ID + viewer: User instance if authenticated, None if anonymous + + Returns: + True if access allowed, False if denied + """ + # Default to deny if no handler is registered (safe default) + if not _on_shared_thread_view_callback: + return False + + try: + return await _on_shared_thread_view_callback(thread_id, viewer) + except Exception: + # Fail safely on callback errors + return False + + +def get_share_url(token: str, base_url: str = "") -> str: + """Generate the full share URL for a token. + + Args: + token: Share token + base_url: Base URL (e.g., "https://example.com") + + Returns: + Full share URL + """ + base_url = base_url.rstrip("/") + return f"{base_url}/shared/{token}" diff --git a/src/praisonaiui/oauth_providers.py b/src/praisonaiui/oauth_providers.py new file mode 100644 index 0000000..d36a527 --- /dev/null +++ b/src/praisonaiui/oauth_providers.py @@ -0,0 +1,451 @@ +"""OAuth2 providers for PraisonAIUI authentication.""" + +from __future__ import annotations + +import asyncio +import os +import secrets +import time +import urllib.parse +from abc import ABC, abstractmethod +from typing import Any, Optional + +import httpx +from starlette.responses import RedirectResponse + +# Shared HTTP client for connection pooling +_http_client: Optional[httpx.AsyncClient] = None + + +def get_http_client() -> httpx.AsyncClient: + """Get shared HTTP client for OAuth requests.""" + global _http_client + if _http_client is None: + _http_client = httpx.AsyncClient( + timeout=httpx.Timeout(30.0), + limits=httpx.Limits(max_connections=20, max_keepalive_connections=5), + ) + return _http_client + + +async def close_http_client() -> None: + """Close shared HTTP client (call on shutdown).""" + global _http_client + if _http_client is not None: + await _http_client.aclose() + _http_client = None + + +class OAuthProvider(ABC): + """Base OAuth2 provider.""" + + def __init__(self, client_id: str, client_secret: str, redirect_uri: str): + self.client_id = client_id + self.client_secret = client_secret + self.redirect_uri = redirect_uri + + @abstractmethod + def get_authorize_url(self, state: str) -> str: + """Get the OAuth authorization URL.""" + ... + + @abstractmethod + async def exchange_code_for_token(self, code: str) -> dict[str, Any]: + """Exchange authorization code for access token.""" + ... + + @abstractmethod + async def get_user_info(self, token: dict[str, Any]) -> dict[str, Any]: + """Get user info from the provider API.""" + ... + + @property + @abstractmethod + def scopes(self) -> list[str]: + """Default scopes for this provider.""" + ... + + def create_authorize_response(self, state: str) -> RedirectResponse: + """Create a redirect response to the OAuth authorization URL.""" + url = self.get_authorize_url(state) + return RedirectResponse(url=url, status_code=302) + + +class GitHubProvider(OAuthProvider): + """GitHub OAuth2 provider.""" + + @property + def scopes(self) -> list[str]: + return ["user:email"] + + def get_authorize_url(self, state: str) -> str: + params = { + "client_id": self.client_id, + "redirect_uri": self.redirect_uri, + "scope": " ".join(self.scopes), + "state": state, + "response_type": "code", + } + return f"https://github.com/login/oauth/authorize?{urllib.parse.urlencode(params)}" + + async def exchange_code_for_token(self, code: str) -> dict[str, Any]: + data = { + "client_id": self.client_id, + "client_secret": self.client_secret, + "code": code, + "redirect_uri": self.redirect_uri, + } + headers = {"Accept": "application/json"} + + client = get_http_client() + response = await client.post( + "https://github.com/login/oauth/access_token", + data=data, + headers=headers, + ) + response.raise_for_status() + return response.json() + + async def get_user_info(self, token: dict[str, Any]) -> dict[str, Any]: + access_token = token["access_token"] + headers = { + "Authorization": f"Bearer {access_token}", + "Accept": "application/json", + } + + client = get_http_client() + response = await client.get( + "https://api.github.com/user", + headers=headers, + ) + response.raise_for_status() + return response.json() + + +class GoogleProvider(OAuthProvider): + """Google OAuth2 provider.""" + + @property + def scopes(self) -> list[str]: + return ["openid", "email", "profile"] + + def get_authorize_url(self, state: str) -> str: + params = { + "client_id": self.client_id, + "redirect_uri": self.redirect_uri, + "scope": " ".join(self.scopes), + "state": state, + "response_type": "code", + "access_type": "offline", + } + return f"https://accounts.google.com/o/oauth2/v2/auth?{urllib.parse.urlencode(params)}" + + async def exchange_code_for_token(self, code: str) -> dict[str, Any]: + data = { + "client_id": self.client_id, + "client_secret": self.client_secret, + "code": code, + "grant_type": "authorization_code", + "redirect_uri": self.redirect_uri, + } + + client = get_http_client() + response = await client.post( + "https://oauth2.googleapis.com/token", + data=data, + ) + response.raise_for_status() + return response.json() + + async def get_user_info(self, token: dict[str, Any]) -> dict[str, Any]: + access_token = token["access_token"] + headers = {"Authorization": f"Bearer {access_token}"} + + client = get_http_client() + response = await client.get( + "https://www.googleapis.com/oauth2/v2/userinfo", + headers=headers, + ) + response.raise_for_status() + return response.json() + + +class AzureADProvider(OAuthProvider): + """Azure Active Directory OAuth2 provider.""" + + def __init__( + self, client_id: str, client_secret: str, redirect_uri: str, tenant_id: str = "common" + ): + super().__init__(client_id, client_secret, redirect_uri) + self.tenant_id = tenant_id + + @property + def scopes(self) -> list[str]: + return ["openid", "profile", "email"] + + def get_authorize_url(self, state: str) -> str: + params = { + "client_id": self.client_id, + "redirect_uri": self.redirect_uri, + "scope": " ".join(self.scopes), + "state": state, + "response_type": "code", + } + return f"https://login.microsoftonline.com/{self.tenant_id}/oauth2/v2.0/authorize?{urllib.parse.urlencode(params)}" + + async def exchange_code_for_token(self, code: str) -> dict[str, Any]: + data = { + "client_id": self.client_id, + "client_secret": self.client_secret, + "code": code, + "grant_type": "authorization_code", + "redirect_uri": self.redirect_uri, + "scope": " ".join(self.scopes), + } + + client = get_http_client() + response = await client.post( + f"https://login.microsoftonline.com/{self.tenant_id}/oauth2/v2.0/token", + data=data, + ) + response.raise_for_status() + return response.json() + + async def get_user_info(self, token: dict[str, Any]) -> dict[str, Any]: + access_token = token["access_token"] + headers = {"Authorization": f"Bearer {access_token}"} + + client = get_http_client() + response = await client.get( + "https://graph.microsoft.com/v1.0/me", + headers=headers, + ) + response.raise_for_status() + return response.json() + + +class OktaProvider(OAuthProvider): + """Okta OAuth2 provider.""" + + def __init__(self, client_id: str, client_secret: str, redirect_uri: str, domain: str): + super().__init__(client_id, client_secret, redirect_uri) + self.domain = domain.rstrip("/") + + @property + def scopes(self) -> list[str]: + return ["openid", "profile", "email"] + + def get_authorize_url(self, state: str) -> str: + params = { + "client_id": self.client_id, + "redirect_uri": self.redirect_uri, + "scope": " ".join(self.scopes), + "state": state, + "response_type": "code", + } + return f"{self.domain}/oauth2/v1/authorize?{urllib.parse.urlencode(params)}" + + async def exchange_code_for_token(self, code: str) -> dict[str, Any]: + data = { + "client_id": self.client_id, + "client_secret": self.client_secret, + "code": code, + "grant_type": "authorization_code", + "redirect_uri": self.redirect_uri, + } + headers = {"Accept": "application/json"} + + client = get_http_client() + response = await client.post( + f"{self.domain}/oauth2/v1/token", + data=data, + headers=headers, + ) + response.raise_for_status() + return response.json() + + async def get_user_info(self, token: dict[str, Any]) -> dict[str, Any]: + access_token = token["access_token"] + headers = {"Authorization": f"Bearer {access_token}"} + + client = get_http_client() + response = await client.get( + f"{self.domain}/oauth2/v1/userinfo", + headers=headers, + ) + response.raise_for_status() + return response.json() + + +# ── Provider Factory ──────────────────────────────────────────────────── + + +def create_oauth_provider( + provider_name: str, + client_id: str, + client_secret: str, + redirect_uri: str, + **kwargs: Any, +) -> OAuthProvider: + """Create an OAuth provider instance from configuration. + + Args: + provider_name: Name of the provider ("github", "google", "azure", "okta") + client_id: OAuth client ID + client_secret: OAuth client secret + redirect_uri: OAuth redirect URI + **kwargs: Provider-specific configuration + + Returns: + OAuthProvider instance + + Raises: + ValueError: If provider_name is not supported + """ + if provider_name == "github": + return GitHubProvider(client_id, client_secret, redirect_uri) + elif provider_name == "google": + return GoogleProvider(client_id, client_secret, redirect_uri) + elif provider_name == "azure": + tenant_id = kwargs.get("tenant_id", "common") + return AzureADProvider(client_id, client_secret, redirect_uri, tenant_id) + elif provider_name == "okta": + domain = kwargs.get("domain") + if not domain: + raise ValueError("Okta provider requires 'domain' parameter") + return OktaProvider(client_id, client_secret, redirect_uri, domain) + else: + raise ValueError(f"Unsupported OAuth provider: {provider_name}") + + +def get_oauth_config_from_env(provider_name: str) -> Optional[dict[str, str]]: + """Get OAuth configuration from environment variables. + + Expected environment variables: + - AIUI_OAUTH_{PROVIDER}_CLIENT_ID + - AIUI_OAUTH_{PROVIDER}_CLIENT_SECRET + - AIUI_OAUTH_{PROVIDER}_REDIRECT_URI (optional, defaults to auto-generated) + + Additional provider-specific env vars: + - AIUI_OAUTH_AZURE_TENANT_ID (for Azure AD) + - AIUI_OAUTH_OKTA_DOMAIN (for Okta) + + Args: + provider_name: Name of the provider ("github", "google", "azure", "okta") + + Returns: + Dict with OAuth config or None if not configured + """ + provider_upper = provider_name.upper() + client_id = os.environ.get(f"AIUI_OAUTH_{provider_upper}_CLIENT_ID") + client_secret = os.environ.get(f"AIUI_OAUTH_{provider_upper}_CLIENT_SECRET") + + if not client_id or not client_secret: + return None + + config = { + "client_id": client_id, + "client_secret": client_secret, + } + + # Optional redirect URI (auto-generated if not provided) + redirect_uri = os.environ.get(f"AIUI_OAUTH_{provider_upper}_REDIRECT_URI") + if redirect_uri: + config["redirect_uri"] = redirect_uri + + # Provider-specific config + if provider_name == "azure": + tenant_id = os.environ.get("AIUI_OAUTH_AZURE_TENANT_ID", "common") + config["tenant_id"] = tenant_id + elif provider_name == "okta": + domain = os.environ.get("AIUI_OAUTH_OKTA_DOMAIN") + if domain: + config["domain"] = domain + + return config + + +# ── OAuth State Management ────────────────────────────────────────────── + +_oauth_states: dict[str, dict[str, Any]] = {} + + +def create_oauth_state(provider: str, return_url: str = "/") -> str: + """Create a new OAuth state token. + + Args: + provider: OAuth provider name + return_url: URL to redirect to after successful authentication + + Returns: + State token string + """ + state = secrets.token_urlsafe(32) + _oauth_states[state] = { + "provider": provider, + "return_url": return_url, + "created_at": time.time(), # Use wall clock time for expiry + } + return state + + +def validate_oauth_state(state: str) -> Optional[dict[str, Any]]: + """Validate and consume an OAuth state token. + + Args: + state: State token to validate + + Returns: + State data if valid, None if invalid/expired + """ + if state not in _oauth_states: + return None + + state_data = _oauth_states.pop(state) # Consume state (one-time use) + + # Check expiry (10 minutes) + current_time = time.time() + if current_time - state_data["created_at"] > 600: + return None + + return state_data + + +def cleanup_expired_states() -> None: + """Clean up expired OAuth states (called periodically).""" + current_time = time.time() + expired_states = [ + state for state, data in _oauth_states.items() if current_time - data["created_at"] > 600 + ] + for state in expired_states: + _oauth_states.pop(state, None) + + +_cleanup_task: Optional[asyncio.Task] = None + + +async def _periodic_cleanup() -> None: + """Run OAuth state cleanup periodically.""" + while True: + try: + await asyncio.sleep(300) # Clean up every 5 minutes + cleanup_expired_states() + except asyncio.CancelledError: + break + except Exception: + # Log error but don't crash the cleanup task + pass + + +def start_cleanup_task() -> None: + """Start the periodic cleanup task.""" + global _cleanup_task + if _cleanup_task is None or _cleanup_task.done(): + _cleanup_task = asyncio.create_task(_periodic_cleanup()) + + +def stop_cleanup_task() -> None: + """Stop the periodic cleanup task.""" + global _cleanup_task + if _cleanup_task and not _cleanup_task.done(): + _cleanup_task.cancel() diff --git a/src/praisonaiui/server.py b/src/praisonaiui/server.py index c597ac6..2c190de 100644 --- a/src/praisonaiui/server.py +++ b/src/praisonaiui/server.py @@ -18,7 +18,13 @@ from starlette.middleware import Middleware from starlette.middleware.cors import CORSMiddleware from starlette.requests import Request -from starlette.responses import HTMLResponse, JSONResponse, Response, StreamingResponse +from starlette.responses import ( + HTMLResponse, + JSONResponse, + RedirectResponse, + Response, + StreamingResponse, +) from starlette.routing import Mount, Route from starlette.staticfiles import StaticFiles @@ -1418,6 +1424,232 @@ async def api_mcp_disconnect(request: Request) -> JSONResponse: return JSONResponse({"error": str(e)}, status_code=500) +# ── OAuth / thread-sharing handlers (issue #22) ───────────────────── + + +async def oauth_authorize_handler(request: Request) -> Response: + """Handle OAuth authorization request.""" + provider_name = request.path_params["provider"] + try: + from praisonaiui.oauth_providers import ( + create_oauth_provider, + create_oauth_state, + get_oauth_config_from_env, + ) + + config = get_oauth_config_from_env(provider_name) + if not config: + return JSONResponse( + {"error": f"OAuth provider '{provider_name}' not configured"}, + status_code=400, + ) + if "redirect_uri" not in config: + base_url = str(request.base_url).rstrip("/") + config["redirect_uri"] = f"{base_url}/api/auth/oauth/{provider_name}/callback" + provider = create_oauth_provider(provider_name, **config) + return_url = request.query_params.get("return_url", "/") + state = create_oauth_state(provider_name, return_url) + return provider.create_authorize_response(state) + except ImportError: + return JSONResponse( + {"error": "OAuth providers not available (missing httpx dependency)"}, + status_code=500, + ) + except Exception as e: + return JSONResponse({"error": f"OAuth setup error: {e!s}"}, status_code=500) + + +async def oauth_callback_handler(request: Request) -> Response: + """Handle OAuth callback from provider.""" + provider_name = request.path_params["provider"] + try: + from praisonaiui.auth import User, _oauth_callbacks, create_token + from praisonaiui.oauth_providers import ( + create_oauth_provider, + get_oauth_config_from_env, + validate_oauth_state, + ) + + code = request.query_params.get("code") + state = request.query_params.get("state") + error = request.query_params.get("error") + if error: + return JSONResponse({"error": f"OAuth error: {error}"}, status_code=400) + if not code or not state: + return JSONResponse({"error": "Missing authorization code or state"}, status_code=400) + state_data = validate_oauth_state(state) + if not state_data or state_data["provider"] != provider_name: + return JSONResponse({"error": "Invalid or expired state parameter"}, status_code=400) + config = get_oauth_config_from_env(provider_name) + if not config: + return JSONResponse( + {"error": f"OAuth provider '{provider_name}' not configured"}, + status_code=500, + ) + if "redirect_uri" not in config: + base_url = str(request.base_url).rstrip("/") + config["redirect_uri"] = f"{base_url}/api/auth/oauth/{provider_name}/callback" + provider = create_oauth_provider(provider_name, **config) + token_data = await provider.exchange_code_for_token(code) + user_info = await provider.get_user_info(token_data) + + if provider_name == "github": + default_user = User( + identifier=f"github:{user_info['login']}", + display_name=user_info.get("name") or user_info["login"], + metadata={ + "avatar": user_info.get("avatar_url"), + "login": user_info["login"], + }, + ) + elif provider_name == "google": + default_user = User( + identifier=f"google:{user_info['id']}", + display_name=user_info.get("name") or user_info["email"], + metadata={ + "avatar": user_info.get("picture"), + "email": user_info["email"], + }, + ) + elif provider_name == "azure": + default_user = User( + identifier=f"azure:{user_info['id']}", + display_name=user_info.get("displayName") or user_info["userPrincipalName"], + metadata={"email": user_info.get("userPrincipalName")}, + ) + elif provider_name == "okta": + default_user = User( + identifier=f"okta:{user_info['sub']}", + display_name=user_info.get("name") or user_info["preferred_username"], + metadata={"email": user_info.get("email")}, + ) + else: + default_user = User( + identifier=f"{provider_name}:{user_info.get('id', user_info.get('sub', 'unknown'))}", + display_name=user_info.get("name", user_info.get("login", "Unknown")), + ) + + cb = _oauth_callbacks.get(provider_name) + if cb: + try: + user = await cb(provider_name, token_data, user_info, default_user) + if user is None: + return JSONResponse( + {"error": "Authentication denied by application"}, + status_code=401, + ) + except Exception as e: + return JSONResponse({"error": f"OAuth callback error: {e!s}"}, status_code=500) + else: + user = default_user + + token = create_token(user.identifier) + if "application/json" in request.headers.get("accept", ""): + return JSONResponse({"user": user.to_dict(), "token": token}) + return_url = state_data.get("return_url", "/") + return RedirectResponse(url=f"{return_url}?token={token}", status_code=302) + except ImportError: + return JSONResponse({"error": "OAuth providers not available"}, status_code=500) + except Exception as e: + return JSONResponse({"error": f"OAuth callback error: {e!s}"}, status_code=500) + + +async def create_thread_share_handler(request: Request) -> JSONResponse: + """Create a share token for a thread.""" + from praisonaiui.auth import validate_token + + thread_id = request.path_params["thread_id"] + auth_header = request.headers.get("Authorization", "") + if not auth_header.startswith("Bearer "): + return JSONResponse({"error": "Authentication required"}, status_code=401) + user_id = validate_token(auth_header[7:]) + if not user_id: + return JSONResponse({"error": "Invalid token"}, status_code=401) + session = await _datastore.get_session(thread_id) + if not session: + return JSONResponse({"error": "Thread not found"}, status_code=404) + try: + from praisonaiui.features.sharing import create_share_token, get_share_url + + share_token = create_share_token(thread_id, user_id) + base_url = str(request.base_url).rstrip("/") + share_url = get_share_url(share_token, base_url) + return JSONResponse({"url": share_url, "token": share_token}) + except ImportError: + return JSONResponse({"error": "Sharing feature not available"}, status_code=500) + + +async def revoke_thread_share_handler(request: Request) -> JSONResponse: + """Revoke share token(s) for a thread.""" + from praisonaiui.auth import validate_token + + thread_id = request.path_params["thread_id"] + auth_header = request.headers.get("Authorization", "") + if not auth_header.startswith("Bearer "): + return JSONResponse({"error": "Authentication required"}, status_code=401) + user_id = validate_token(auth_header[7:]) + if not user_id: + return JSONResponse({"error": "Invalid token"}, status_code=401) + try: + from praisonaiui.features.sharing import revoke_share_token + + return JSONResponse( + {"revoked": revoke_share_token(thread_id, user_id), "thread_id": thread_id} + ) + except ImportError: + return JSONResponse({"error": "Sharing feature not available"}, status_code=500) + + +async def view_shared_thread_handler(request: Request) -> Response: + """View a shared thread (read-only).""" + share_token = request.path_params["token"] + try: + from praisonaiui.auth import User, _users, validate_token + from praisonaiui.features.sharing import ( + check_shared_thread_access, + get_thread_by_share_token, + ) + + thread_id = get_thread_by_share_token(share_token) + if not thread_id: + return JSONResponse({"error": "Invalid or expired share link"}, status_code=404) + session = await _datastore.get_session(thread_id) + if not session: + return JSONResponse({"error": "Thread not found"}, status_code=404) + + user = None + auth_header = request.headers.get("Authorization", "") + if auth_header.startswith("Bearer "): + uid = validate_token(auth_header[7:]) + if uid: + if uid in _users: + user_data = _users[uid] + user = User( + identifier=user_data["id"], + display_name=user_data["username"], + metadata={}, + ) + else: + user = User(identifier=uid, display_name=uid) + + if not await check_shared_thread_access(thread_id, user): + return JSONResponse({"error": "Access denied"}, status_code=403) + + if "application/json" in request.headers.get("accept", ""): + messages = await _datastore.get_messages(thread_id) + return JSONResponse( + { + "thread_id": thread_id, + "session": session, + "messages": messages, + "read_only": True, + } + ) + return HTMLResponse(_build_html(_effective_style)) + except ImportError: + return JSONResponse({"error": "Sharing feature not available"}, status_code=500) + + async def api_gateway_status(request: Request) -> JSONResponse: """GET /api/gateway/status — real gateway connectivity and agent info.""" try: @@ -2359,6 +2591,24 @@ async def _route_manifest_json(request: Request) -> JSONResponse: Route("/api/mcp/servers", api_mcp_servers, methods=["GET"]), Route("/api/mcp/connect", api_mcp_connect, methods=["POST"]), Route("/api/mcp/disconnect/{server_id}", api_mcp_disconnect, methods=["POST"]), + # OAuth + thread-sharing (issue #22) + Route("/api/auth/oauth/{provider}", oauth_authorize_handler, methods=["GET"]), + Route( + "/api/auth/oauth/{provider}/callback", + oauth_callback_handler, + methods=["GET"], + ), + Route( + "/api/threads/{thread_id}/share", + create_thread_share_handler, + methods=["POST"], + ), + Route( + "/api/threads/{thread_id}/unshare", + revoke_thread_share_handler, + methods=["POST"], + ), + Route("/shared/{token}", view_shared_thread_handler, methods=["GET"]), Route("/api/gateway/status", api_gateway_status, methods=["GET"]), # Frontend config JSON (dynamic fallback — static files override if present) Route("/ui-config.json", _ui_config_json, methods=["GET"]), diff --git a/tests/unit/test_oauth.py b/tests/unit/test_oauth.py new file mode 100644 index 0000000..3dafd5c --- /dev/null +++ b/tests/unit/test_oauth.py @@ -0,0 +1,339 @@ +"""Tests for OAuth authentication functionality.""" + +import os +from unittest.mock import AsyncMock, Mock, patch +import pytest +from starlette.applications import Starlette +from starlette.testclient import TestClient + +from praisonaiui.auth import User, oauth_callback, _oauth_callbacks +from praisonaiui.oauth_providers import ( + GitHubProvider, + GoogleProvider, + AzureADProvider, + OktaProvider, + create_oauth_provider, + get_oauth_config_from_env, + create_oauth_state, + validate_oauth_state, +) + + +class TestOAuthProviders: + """Test OAuth provider implementations.""" + + def test_github_provider(self): + """Test GitHub provider URLs and configuration.""" + provider = GitHubProvider( + client_id="test_client_id", + client_secret="test_secret", + redirect_uri="http://localhost:8000/auth/callback" + ) + + assert provider.scopes == ["user:email"] + + state = "test_state_123" + auth_url = provider.get_authorize_url(state) + + assert "github.com/login/oauth/authorize" in auth_url + assert "client_id=test_client_id" in auth_url + assert f"state={state}" in auth_url + assert "scope=user%3Aemail" in auth_url + + def test_google_provider(self): + """Test Google provider URLs and configuration.""" + provider = GoogleProvider( + client_id="test_client_id", + client_secret="test_secret", + redirect_uri="http://localhost:8000/auth/callback" + ) + + assert provider.scopes == ["openid", "email", "profile"] + + state = "test_state_456" + auth_url = provider.get_authorize_url(state) + + assert "accounts.google.com/o/oauth2/v2/auth" in auth_url + assert "client_id=test_client_id" in auth_url + assert f"state={state}" in auth_url + + def test_azure_provider(self): + """Test Azure AD provider URLs and configuration.""" + provider = AzureADProvider( + client_id="test_client_id", + client_secret="test_secret", + redirect_uri="http://localhost:8000/auth/callback", + tenant_id="common" + ) + + assert provider.scopes == ["openid", "profile", "email"] + + state = "test_state_789" + auth_url = provider.get_authorize_url(state) + + assert "login.microsoftonline.com/common/oauth2/v2.0/authorize" in auth_url + assert "client_id=test_client_id" in auth_url + assert f"state={state}" in auth_url + + def test_okta_provider(self): + """Test Okta provider URLs and configuration.""" + provider = OktaProvider( + client_id="test_client_id", + client_secret="test_secret", + redirect_uri="http://localhost:8000/auth/callback", + domain="https://dev-123456.okta.com" + ) + + assert provider.scopes == ["openid", "profile", "email"] + + state = "test_state_okta" + auth_url = provider.get_authorize_url(state) + + assert "dev-123456.okta.com/oauth2/v1/authorize" in auth_url + assert "client_id=test_client_id" in auth_url + assert f"state={state}" in auth_url + + def test_provider_factory(self): + """Test OAuth provider factory.""" + # Test GitHub + provider = create_oauth_provider( + "github", "client_id", "secret", "redirect_uri" + ) + assert isinstance(provider, GitHubProvider) + + # Test Google + provider = create_oauth_provider( + "google", "client_id", "secret", "redirect_uri" + ) + assert isinstance(provider, GoogleProvider) + + # Test Azure + provider = create_oauth_provider( + "azure", "client_id", "secret", "redirect_uri", tenant_id="test" + ) + assert isinstance(provider, AzureADProvider) + + # Test Okta + provider = create_oauth_provider( + "okta", "client_id", "secret", "redirect_uri", domain="https://test.okta.com" + ) + assert isinstance(provider, OktaProvider) + + # Test unknown provider + with pytest.raises(ValueError, match="Unsupported OAuth provider"): + create_oauth_provider("unknown", "client_id", "secret", "redirect_uri") + + +class TestOAuthEnvironmentConfig: + """Test OAuth configuration from environment variables.""" + + def test_get_github_config_from_env(self): + """Test GitHub configuration from environment.""" + with patch.dict(os.environ, { + "AIUI_OAUTH_GITHUB_CLIENT_ID": "gh_client_123", + "AIUI_OAUTH_GITHUB_CLIENT_SECRET": "gh_secret_456", + "AIUI_OAUTH_GITHUB_REDIRECT_URI": "http://localhost:8000/callback" + }): + config = get_oauth_config_from_env("github") + assert config == { + "client_id": "gh_client_123", + "client_secret": "gh_secret_456", + "redirect_uri": "http://localhost:8000/callback" + } + + def test_get_azure_config_from_env(self): + """Test Azure configuration from environment.""" + with patch.dict(os.environ, { + "AIUI_OAUTH_AZURE_CLIENT_ID": "az_client_123", + "AIUI_OAUTH_AZURE_CLIENT_SECRET": "az_secret_456", + "AIUI_OAUTH_AZURE_TENANT_ID": "my-tenant" + }): + config = get_oauth_config_from_env("azure") + assert config == { + "client_id": "az_client_123", + "client_secret": "az_secret_456", + "tenant_id": "my-tenant" + } + + def test_get_missing_config_from_env(self): + """Test behavior when OAuth config is missing.""" + with patch.dict(os.environ, {}, clear=True): + config = get_oauth_config_from_env("github") + assert config is None + + +class TestOAuthStateManagement: + """Test OAuth state token management.""" + + def test_create_and_validate_state(self): + """Test state creation and validation.""" + state = create_oauth_state("github", "/dashboard") + assert len(state) > 0 + + state_data = validate_oauth_state(state) + assert state_data is not None + assert state_data["provider"] == "github" + assert state_data["return_url"] == "/dashboard" + + # State should be consumed after validation + state_data_again = validate_oauth_state(state) + assert state_data_again is None + + def test_validate_invalid_state(self): + """Test validation of invalid state.""" + invalid_state = "invalid_state_token" + state_data = validate_oauth_state(invalid_state) + assert state_data is None + + +class TestOAuthCallbacks: + """Test OAuth callback functionality.""" + + def test_oauth_callback_decorator(self): + """Test OAuth callback decorator registration.""" + # Clear existing callbacks + _oauth_callbacks.clear() + + @oauth_callback("github") + async def handle_github_oauth(provider, token, raw_user, default_user): + return User( + identifier=f"github:{raw_user['login']}", + display_name=raw_user["name"], + metadata={"login": raw_user["login"]} + ) + + assert "github" in _oauth_callbacks + assert _oauth_callbacks["github"] == handle_github_oauth + + def test_oauth_callback_deny(self): + """Test OAuth callback that denies authentication.""" + # Clear existing callbacks + _oauth_callbacks.clear() + + @oauth_callback("google") + async def handle_google_oauth(provider, token, raw_user, default_user): + # Deny authentication for test + return None + + assert "google" in _oauth_callbacks + callback = _oauth_callbacks["google"] + + # Mock data + token_data = {"access_token": "test_token"} + user_info = {"id": "123", "email": "test@example.com"} + default_user = User(identifier="google:123", display_name="Test User") + + # Test that callback returns None (denies auth) + import asyncio + result = asyncio.run(callback("google", token_data, user_info, default_user)) + assert result is None + + +@pytest.mark.asyncio +class TestOAuthIntegration: + """Test OAuth integration with server routes.""" + + @patch('praisonaiui.oauth_providers.get_http_client') + async def test_github_oauth_round_trip(self, mock_get_client): + """Test complete GitHub OAuth flow.""" + from praisonaiui.server import create_app + + # Clear existing callbacks to avoid interference from other tests + _oauth_callbacks.clear() + + # Register a mock OAuth callback for this test + @oauth_callback("github") + async def test_github_callback(provider, token, raw_user, default_user): + return User( + identifier=f"github:{raw_user['login']}", + display_name=raw_user["name"], + metadata={"login": raw_user["login"]} + ) + + # Mock OAuth provider responses + mock_response = Mock() + mock_response.json.return_value = {"access_token": "github_token_123"} + mock_response.raise_for_status.return_value = None + + mock_user_response = Mock() + mock_user_response.json.return_value = { + "login": "testuser", + "name": "Test User", + "avatar_url": "https://github.com/testuser.png" + } + mock_user_response.raise_for_status.return_value = None + + mock_client_instance = AsyncMock() + mock_client_instance.post.return_value = mock_response + mock_client_instance.get.return_value = mock_user_response + mock_get_client.return_value = mock_client_instance + + # Set up OAuth config + with patch.dict(os.environ, { + "AIUI_OAUTH_GITHUB_CLIENT_ID": "test_client", + "AIUI_OAUTH_GITHUB_CLIENT_SECRET": "test_secret" + }): + app = create_app() + client = TestClient(app) + + # Test authorization redirect + response = client.get("/api/auth/oauth/github", follow_redirects=False) + assert response.status_code == 302 + assert "github.com/login/oauth/authorize" in response.headers["location"] + + # Extract state from redirect URL + location = response.headers["location"] + state = location.split("state=")[1].split("&")[0] if "state=" in location else "test_state" + + # Test callback + with patch('praisonaiui.oauth_providers.validate_oauth_state') as mock_validate: + mock_validate.return_value = {"provider": "github", "return_url": "/"} + + callback_response = client.get( + f"/api/auth/oauth/github/callback?code=test_code&state={state}", + headers={"Accept": "application/json"} + ) + + if callback_response.status_code != 200: + print(f"Callback error: {callback_response.status_code}") + print(f"Callback response: {callback_response.content}") + + assert callback_response.status_code == 200 + data = callback_response.json() + assert "user" in data + assert "token" in data + assert data["user"]["identifier"] == "github:testuser" + + async def test_oauth_callback_with_custom_handler(self): + """Test OAuth callback with custom user handler.""" + # Clear existing callbacks + _oauth_callbacks.clear() + + @oauth_callback("github") + async def custom_github_handler(provider, token, raw_user, default_user): + # Custom logic - only allow specific users + if raw_user["login"] == "allowed_user": + return User( + identifier=f"custom:{raw_user['login']}", + display_name=f"Custom {raw_user['name']}", + metadata={"custom": True} + ) + return None # Deny other users + + # Test allowed user + token_data = {"access_token": "test"} + user_info = {"login": "allowed_user", "name": "Allowed User"} + default_user = User(identifier="github:allowed_user", display_name="Allowed User") + + result = await custom_github_handler("github", token_data, user_info, default_user) + assert result is not None + assert result.identifier == "custom:allowed_user" + assert result.display_name == "Custom Allowed User" + assert result.metadata["custom"] is True + + # Test denied user + user_info = {"login": "denied_user", "name": "Denied User"} + default_user = User(identifier="github:denied_user", display_name="Denied User") + + result = await custom_github_handler("github", token_data, user_info, default_user) + assert result is None \ No newline at end of file diff --git a/tests/unit/test_sharing.py b/tests/unit/test_sharing.py new file mode 100644 index 0000000..294149c --- /dev/null +++ b/tests/unit/test_sharing.py @@ -0,0 +1,409 @@ +"""Tests for thread sharing functionality.""" + +import pytest +from unittest.mock import AsyncMock, Mock, patch +from starlette.applications import Starlette +from starlette.testclient import TestClient + +from praisonaiui.auth import User, create_token +from praisonaiui.features.sharing import ( + on_shared_thread_view, + create_share_token, + get_thread_by_share_token, + revoke_share_token, + list_share_tokens, + check_shared_thread_access, + get_share_url, + _on_shared_thread_view_callback, + _share_tokens, +) + + +class TestSharingTokens: + """Test share token creation and management.""" + + def setup_method(self): + """Clear share tokens before each test.""" + _share_tokens.clear() + + def test_create_share_token(self): + """Test share token creation.""" + thread_id = "session_123" + user_id = "user_456" + + token = create_share_token(thread_id, user_id) + + assert len(token) > 0 + assert token in _share_tokens + + token_data = _share_tokens[token] + assert token_data["thread_id"] == thread_id + assert token_data["created_by"] == user_id + assert "created_at" in token_data + + def test_get_thread_by_share_token(self): + """Test retrieving thread ID from share token.""" + thread_id = "session_789" + user_id = "user_123" + + token = create_share_token(thread_id, user_id) + + # Valid token + result = get_thread_by_share_token(token) + assert result == thread_id + + # Invalid token + result = get_thread_by_share_token("invalid_token") + assert result is None + + def test_revoke_share_token(self): + """Test share token revocation.""" + thread_id = "session_456" + user_id = "user_789" + other_user_id = "other_user" + + token = create_share_token(thread_id, user_id) + + # Only creator can revoke + revoked = revoke_share_token(thread_id, other_user_id) + assert revoked is False + assert token in _share_tokens + + # Creator can revoke + revoked = revoke_share_token(thread_id, user_id) + assert revoked is True + assert token not in _share_tokens + + def test_list_share_tokens(self): + """Test listing share tokens by user.""" + user1_id = "user_1" + user2_id = "user_2" + + token1 = create_share_token("thread_1", user1_id) + token2 = create_share_token("thread_2", user1_id) + token3 = create_share_token("thread_3", user2_id) + + # User 1's tokens + user1_tokens = list_share_tokens(user1_id) + assert len(user1_tokens) == 2 + token_values = [t["token"] for t in user1_tokens] + assert token1 in token_values + assert token2 in token_values + + # User 2's tokens + user2_tokens = list_share_tokens(user2_id) + assert len(user2_tokens) == 1 + assert user2_tokens[0]["token"] == token3 + + def test_get_share_url(self): + """Test share URL generation.""" + token = "test_token_123" + + # Without base URL + url = get_share_url(token) + assert url == "/shared/test_token_123" + + # With base URL + url = get_share_url(token, "https://example.com") + assert url == "https://example.com/shared/test_token_123" + + # With base URL ending in slash + url = get_share_url(token, "https://example.com/") + assert url == "https://example.com/shared/test_token_123" + + +class TestSharingAccessControl: + """Test sharing access control via callbacks.""" + + def setup_method(self): + """Clear callbacks before each test.""" + global _on_shared_thread_view_callback + _on_shared_thread_view_callback = None + + def test_on_shared_thread_view_decorator(self): + """Test shared thread view decorator registration.""" + from praisonaiui.features.sharing import _on_shared_thread_view_callback + + @on_shared_thread_view + async def handle_share_view(thread_id, viewer): + return True + + # Import fresh reference to check the module variable was set + import praisonaiui.features.sharing as sharing_module + assert sharing_module._on_shared_thread_view_callback == handle_share_view + + @pytest.mark.asyncio + async def test_check_shared_thread_access_allow(self): + """Test shared thread access check - allow.""" + @on_shared_thread_view + async def allow_access(thread_id, viewer): + return True + + result = await check_shared_thread_access("thread_123", None) + assert result is True + + user = User(identifier="user_123", display_name="Test User") + result = await check_shared_thread_access("thread_456", user) + assert result is True + + @pytest.mark.asyncio + async def test_check_shared_thread_access_deny(self): + """Test shared thread access check - deny.""" + @on_shared_thread_view + async def deny_access(thread_id, viewer): + return False + + result = await check_shared_thread_access("thread_789", None) + assert result is False + + @pytest.mark.asyncio + async def test_check_shared_thread_access_conditional(self): + """Test conditional access control.""" + @on_shared_thread_view + async def conditional_access(thread_id, viewer): + # Allow logged-in users, deny anonymous + return viewer is not None + + # Anonymous user + result = await check_shared_thread_access("thread_abc", None) + assert result is False + + # Logged-in user + user = User(identifier="user_def", display_name="Logged User") + result = await check_shared_thread_access("thread_abc", user) + assert result is True + + @pytest.mark.asyncio + async def test_check_shared_thread_access_no_handler(self): + """Test access check with no handler registered (safe default).""" + # No handler registered - should default to deny + result = await check_shared_thread_access("thread_xyz", None) + assert result is False + + @pytest.mark.asyncio + async def test_check_shared_thread_access_callback_error(self): + """Test access check when callback raises exception.""" + @on_shared_thread_view + async def error_callback(thread_id, viewer): + raise Exception("Callback error") + + # Should fail safely to deny access + result = await check_shared_thread_access("thread_error", None) + assert result is False + + +@pytest.mark.asyncio +class TestSharingRoutes: + """Test sharing HTTP routes.""" + + def setup_method(self): + """Clear share tokens before each test.""" + _share_tokens.clear() + + async def test_create_share_endpoint(self): + """Test POST /api/threads/{id}/share endpoint.""" + from praisonaiui.server import create_app + from praisonaiui.datastore import MemoryDataStore + + # Set up test datastore + datastore = MemoryDataStore() + session = await datastore.create_session("test_session") + + app = create_app() + + with patch('praisonaiui.server._datastore', datastore): + client = TestClient(app) + + # Create test token + token = create_token("test_user") + + response = client.post( + f"/api/threads/{session['id']}/share", + headers={"Authorization": f"Bearer {token}"} + ) + + assert response.status_code == 200 + data = response.json() + assert "url" in data + assert "token" in data + assert "/shared/" in data["url"] + + async def test_create_share_unauthenticated(self): + """Test share creation without authentication.""" + from praisonaiui.server import create_app + + app = create_app() + client = TestClient(app) + + response = client.post("/api/threads/test_session/share") + assert response.status_code == 401 + + async def test_create_share_invalid_thread(self): + """Test share creation for non-existent thread.""" + from praisonaiui.server import create_app + from praisonaiui.datastore import MemoryDataStore + + app = create_app() + + with patch('praisonaiui.server._datastore', MemoryDataStore()): + client = TestClient(app) + + # Create test token + token = create_token("test_user") + + response = client.post( + "/api/threads/nonexistent_session/share", + headers={"Authorization": f"Bearer {token}"} + ) + + assert response.status_code == 404 + + async def test_revoke_share_endpoint(self): + """Test POST /api/threads/{id}/unshare endpoint.""" + from praisonaiui.server import create_app + from praisonaiui.datastore import MemoryDataStore + + # Set up test datastore + datastore = MemoryDataStore() + session = await datastore.create_session("test_session_revoke") + + app = create_app() + + with patch('praisonaiui.server._datastore', datastore): + client = TestClient(app) + + # Create test token and share + token = create_token("test_user") + share_token = create_share_token(session['id'], "test_user") + + response = client.post( + f"/api/threads/{session['id']}/unshare", + headers={"Authorization": f"Bearer {token}"} + ) + + assert response.status_code == 200 + data = response.json() + assert data["revoked"] is True + assert data["thread_id"] == session['id'] + + # Verify token was revoked + assert get_thread_by_share_token(share_token) is None + + async def test_view_shared_thread_endpoint(self): + """Test GET /shared/{token} endpoint.""" + from praisonaiui.server import create_app + from praisonaiui.datastore import MemoryDataStore + + # Set up test datastore + datastore = MemoryDataStore() + session = await datastore.create_session("test_session_view") + await datastore.add_message(session['id'], { + "role": "user", + "content": "Hello", + "timestamp": "2024-01-01T00:00:00Z" + }) + await datastore.add_message(session['id'], { + "role": "assistant", + "content": "Hi there!", + "timestamp": "2024-01-01T00:00:01Z" + }) + + # Set up sharing callback + @on_shared_thread_view + async def allow_all(thread_id, viewer): + return True + + app = create_app() + + with patch('praisonaiui.server._datastore', datastore): + client = TestClient(app) + + # Create share token + share_token = create_share_token(session['id'], "test_user") + + # Test JSON API access + response = client.get( + f"/shared/{share_token}", + headers={"Accept": "application/json"} + ) + + assert response.status_code == 200 + data = response.json() + assert data["thread_id"] == session['id'] + assert data["read_only"] is True + assert len(data["messages"]) == 2 + assert data["messages"][0]["content"] == "Hello" + assert data["messages"][1]["content"] == "Hi there!" + + async def test_view_shared_thread_access_denied(self): + """Test shared thread access denial.""" + from praisonaiui.server import create_app + from praisonaiui.datastore import MemoryDataStore + + # Set up test datastore + datastore = MemoryDataStore() + session = await datastore.create_session("test_session_denied") + + # Set up sharing callback that denies access + @on_shared_thread_view + async def deny_all(thread_id, viewer): + return False + + app = create_app() + + with patch('praisonaiui.server._datastore', datastore): + client = TestClient(app) + + # Create share token + share_token = create_share_token(session['id'], "test_user") + + response = client.get( + f"/shared/{share_token}", + headers={"Accept": "application/json"} + ) + + assert response.status_code == 403 + + async def test_view_shared_thread_invalid_token(self): + """Test shared thread with invalid token.""" + from praisonaiui.server import create_app + + app = create_app() + client = TestClient(app) + + response = client.get( + "/shared/invalid_token_123", + headers={"Accept": "application/json"} + ) + + assert response.status_code == 404 + + async def test_view_shared_thread_html(self): + """Test shared thread HTML view.""" + from praisonaiui.server import create_app + from praisonaiui.datastore import MemoryDataStore + + # Set up test datastore + datastore = MemoryDataStore() + session = await datastore.create_session("test_session_html") + + # Set up sharing callback + @on_shared_thread_view + async def allow_all(thread_id, viewer): + return True + + app = create_app() + + with patch('praisonaiui.server._datastore', datastore): + client = TestClient(app) + + # Create share token + share_token = create_share_token(session['id'], "test_user") + + # Test HTML access (browser) + response = client.get(f"/shared/{share_token}") + + assert response.status_code == 200 + assert "text/html" in response.headers.get("content-type", "") + # Should return the SPA shell for frontend to handle + assert "" in response.text.lower() \ No newline at end of file