diff --git a/docs/Cross_Endpoint_Routing.md b/docs/Cross_Endpoint_Routing.md index 6d059636..8baed1a9 100644 --- a/docs/Cross_Endpoint_Routing.md +++ b/docs/Cross_Endpoint_Routing.md @@ -342,8 +342,8 @@ graph TD A["Function Call"] -->|"intercepts stub layer"| B["ProductionWrapper"] B -->|"load service configuration"| C["ServiceRegistry"] - C -->|"if not cached"| D["DirectoryClient"] - D -->|"query mothership API"| E["Directory
Endpoint URLs"] + C -->|"if not cached"| D["Manifest Endpoint
/manifest"] + D -->|"query mothership"| E["Endpoint URLs
From Deployed
Resources"] E -->|"cache result
TTL 300s"| C C -->|"lookup in manifest
flash_manifest.json"| F{"Routing
Decision"} @@ -358,7 +358,7 @@ graph TD K --> L["Return Response
base64 → cloudpickle"] L --> M["Deserialized Result"] - N["Error Handling:
- RemoteExecutionError
- SerializationError
- DirectoryUnavailableError"] -.-> H + N["Error Handling:
- RemoteExecutionError
- SerializationError
- ManifestServiceUnavailableError"] -.-> H N -.-> I N -.-> J @@ -450,11 +450,11 @@ class ServiceRegistry: """Service discovery and routing for cross-endpoint function calls.""" def __init__(self, manifest_path: Optional[Path] = None): - """Initialize with manifest and optional directory client.""" + """Initialize with manifest and directory caching.""" self._load_manifest(manifest_path) - self._directory_client = DirectoryClient(...) - self._directory = {} # Cached endpoint URLs + self._directory = {} # Cached endpoint URLs from /manifest self._directory_lock = asyncio.Lock() + self._directory_loaded_at = 0 def get_resource_for_function(self, func_name: str) -> Optional[ServerlessResource]: """Get resource config for function from manifest.""" @@ -463,16 +463,25 @@ class ServiceRegistry: # - Explicitly set to null in manifest # Returns ServerlessResource if mapped in manifest - config = self._manifest["functions"].get(func_name) + config = self._manifest["function_registry"].get(func_name) return self._resolve_resource(config) async def _ensure_directory_loaded(self) -> None: - """Load directory from mothership with caching (TTL 300s).""" + """Load manifest directory from mothership with caching (TTL 300s). + + Queries the /manifest endpoint on FLASH_MOTHERSHIP_URL. + """ if self._is_directory_fresh(): return async with self._directory_lock: - self._directory = await self._directory_client.get_directory() + # Query mothership /manifest endpoint + mothership_url = os.getenv("FLASH_MOTHERSHIP_URL") + if not mothership_url: + return # Directory unavailable, graceful fallback + + response = await self._http_client.get(f"{mothership_url}/manifest") + self._directory = response.json().get("manifest", {}) self._directory_loaded_at = time.time() ``` @@ -504,35 +513,46 @@ class ServiceRegistry: - Thread-safe with `asyncio.Lock()` - Graceful fallback if directory unavailable -#### 3. DirectoryClient +#### 3. 
StateManagerClient -**Location**: `src/tetra_rp/runtime/directory_client.py` +**Location**: `src/tetra_rp/runtime/state_manager_client.py` -HTTP client for mothership directory service: +HTTP client for State Manager API (used by mothership auto-provisioning): ```python -class DirectoryClient: - """HTTP client for querying mothership directory. +class StateManagerClient: + """HTTP client for State Manager API. - The directory maps resource_config names to their endpoint URLs. - Example: {"gpu_config": "https://api.runpod.io/v2/abc123"} + The State Manager persists manifest state and provides reconciliation + for detecting new, changed, and removed resources. """ - async def get_directory(self) -> Dict[str, str]: - """Fetch endpoint directory from mothership. + async def get_persisted_manifest( + self, mothership_id: str + ) -> Optional[Dict[str, Any]]: + """Fetch persisted manifest from State Manager. Returns: - Dictionary mapping resource_config_name → endpoint_url. - Example: {"gpu_config": "https://api.runpod.io/v2/abc123"} + Manifest dict or None if not found (first boot). Raises: - DirectoryUnavailableError: If directory service unavailable after retries. + ManifestServiceUnavailableError: If State Manager unavailable. 
""" - # Queries {mothership_url}/directory endpoint with retry logic + # Queries {base_url}/api/v1/flash/manifests/{mothership_id} + + async def update_resource_state( + self, + mothership_id: str, + resource_name: str, + resource_data: Dict[str, Any], + ) -> None: + """Update resource entry in State Manager after deployment.""" + # Queries {base_url}/api/v1/flash/manifests/{mothership_id}/resources/{resource_name} ``` **Configuration**: -- Mothership URL from `FLASH_MOTHERSHIP_URL` env var +- Base URL: `https://api.runpod.io` (default, configurable) +- Authentication: Bearer token using RUNPOD_API_KEY env var - HTTP timeout: 10 seconds (via `DEFAULT_REQUEST_TIMEOUT`) - Retry logic: Exponential backoff with `DEFAULT_MAX_RETRIES` attempts (default: 3) - Uses `httpx` library for async HTTP requests @@ -561,8 +581,8 @@ class ManifestError(FlashRuntimeError): """Raised when manifest is invalid, missing, or has unexpected structure.""" pass -class DirectoryUnavailableError(FlashRuntimeError): - """Raised when directory service is unavailable.""" +class ManifestServiceUnavailableError(FlashRuntimeError): + """Raised when manifest service (mothership /manifest endpoint) is unavailable.""" pass ``` @@ -576,8 +596,8 @@ except SerializationError as e: logger.error(f"Serialization failed: {e}") except ManifestError as e: logger.error(f"Manifest configuration error: {e}") -except DirectoryUnavailableError as e: - logger.warning(f"Directory unavailable, using fallback") +except ManifestServiceUnavailableError as e: + logger.warning(f"Manifest service unavailable, using fallback") ``` ### Integration Points @@ -667,7 +687,7 @@ Manifest Lookup (resource found) ↓ Ensure Directory Loaded ↓ -DirectoryClient.get_endpoints() +Query /manifest Endpoint (from mothership) ↓ Get Remote Endpoint URL ↓ @@ -755,29 +775,35 @@ class JsonSerializer: 2. Update ProductionWrapper to select serializer based on config 3. 
Add tests for new format -#### Adding New Directory Backends +#### Customizing Directory Loading -To support directories other than mothership: +To support alternate directory backends instead of the mothership /manifest endpoint: -1. Create client class with `get_directory()` method: +1. Subclass ServiceRegistry and override `_ensure_directory_loaded()`: ```python -class CustomDirectoryClient: - async def get_directory(self) -> Dict[str, str]: - """Fetch directory mapping resource_config_name → endpoint_url.""" +class CustomDirectoryRegistry(ServiceRegistry): + async def _ensure_directory_loaded(self) -> None: + """Load directory from custom backend instead of /manifest.""" + if self._is_directory_fresh(): + return + + async with self._directory_lock: + # Custom directory loading logic + self._directory = await self._load_custom_directory() + self._directory_loaded_at = time.time() + + async def _load_custom_directory(self) -> Dict[str, str]: + """Load directory from custom backend.""" # Implementation specific to backend return {"resource_name": "https://endpoint.url"} ``` -2. Update ServiceRegistry to accept and use client in constructor: +2. Use custom registry in ProductionWrapper: ```python -registry = ServiceRegistry( - manifest_path=Path("manifest.json"), - directory_client=CustomDirectoryClient(...) -) +registry = CustomDirectoryRegistry(manifest_path=Path("manifest.json")) +wrapper = ProductionWrapper(registry) ``` -3. 
Update environment variable handling if needed (CustomDirectoryClient can read from env vars) - #### Adding Routing Policies To implement routing logic beyond manifest: @@ -806,16 +832,22 @@ class RoutingPolicy: **ServiceRegistry Tests** (`tests/unit/runtime/test_service_registry.py`): - Manifest loading - Resource lookup -- Directory caching +- Directory caching from /manifest - TTL expiry - Lock behavior under concurrency -**DirectoryClient Tests** (`tests/unit/runtime/test_directory_client.py`): -- Successful HTTP requests +**StateManagerClient Tests** (`tests/unit/runtime/test_state_manager_client.py`): +- Successful manifest fetch +- Manifest updates and deletions - Error handling -- Retry logic -- Timeout handling -- URL validation +- Retry logic with exponential backoff +- HTTP timeout handling + +**MothershipProvisioner Tests** (`tests/unit/runtime/test_mothership_provisioner.py`): +- Manifest reconciliation +- Drift detection via config hashing +- Resource creation from manifest +- Directory mapping extraction **ProductionWrapper Tests** (`tests/unit/runtime/test_production_wrapper.py`): - Local execution routing @@ -878,9 +910,9 @@ Enable debug logging to trace routing decisions: import logging logging.basicConfig(level=logging.DEBUG) -# ProductionWrapper logs -# ServiceRegistry logs -# DirectoryClient logs +# ProductionWrapper logs routing decisions +# ServiceRegistry logs manifest and directory queries +# StateManagerClient logs State Manager API requests ``` #### Common Debug Scenarios @@ -890,7 +922,7 @@ logging.basicConfig(level=logging.DEBUG) # Check manifest print(registry._manifest) -# Check directory +# Check directory (from /manifest endpoint) print(registry._directory) # Check resource lookup @@ -908,16 +940,18 @@ except Exception as e: print(f"Not serializable: {e}") ``` -**Directory unavailable**: +**/manifest endpoint unavailable**: ```python # Check environment variables import os print(f"FLASH_MOTHERSHIP_URL: 
{os.getenv('FLASH_MOTHERSHIP_URL')}") print(f"RUNPOD_ENDPOINT_ID: {os.getenv('RUNPOD_ENDPOINT_ID')}") -# Check directory client directly -client = DirectoryClient(mothership_url=...) -endpoints = await client.get_endpoints() +# Check /manifest endpoint directly +import httpx +async with httpx.AsyncClient() as client: + response = await client.get(f"{mothership_url}/manifest") + print(response.json()) ``` ## Key Implementation Highlights @@ -930,6 +964,232 @@ endpoints = await client.get_endpoints() 4. **Thread-Safe Async**: Proper `asyncio.Lock()` usage for concurrent operations 5. **Clear Error Hierarchy**: Custom exceptions provide actionable error context +## Mothership Auto-Provisioning + +### Overview + +Mothership auto-provisioning automates the deployment of child endpoints when the mothership boots. Instead of manually deploying child resources, the mothership reads its `flash_manifest.json`, compares against the persisted manifest in State Manager, and automatically deploys, updates, or removes child resources as needed. + +### How It Works + +#### 1. Mothership Identification + +When a LoadBalancerSlsResource is deployed as the mothership, the system automatically sets: +```python +env["FLASH_IS_MOTHERSHIP"] = "true" +``` + +This environment variable signals to the mothership that it should auto-provision child resources on boot. + +#### 2. Boot Sequence + +When the mothership starts: + +1. **Lifespan Startup Hook**: The FastAPI lifespan context manager starts +2. **Mothership Check**: System checks if `FLASH_IS_MOTHERSHIP=true` +3. **Background Task**: Spawns non-blocking provisioning task via `asyncio.create_task()` +4. **FastAPI Server**: Starts serving requests immediately (not blocked by provisioning) +5. **Directory Available**: `/manifest` endpoint returns partial results during provisioning + +#### 3. 
Manifest Reconciliation + +The mothership compares local manifest with State Manager's persisted manifest: + +**New Resources** (in local, not in State Manager): +- Created with `ResourceManager.get_or_deploy_resource()` +- `FLASH_MOTHERSHIP_URL` env var set on child +- State Manager updated with resource entry (hash, endpoint_url, status) + +**Changed Resources** (different config hash): +- Updated with `ResourceManager.get_or_deploy_resource()` +- Config hash recomputed and State Manager updated + +**Removed Resources** (in State Manager, not in local): +- Undeployed with `ResourceManager.undeploy_resource()` +- Removed from State Manager + +**Unchanged Resources** (same config hash): +- Skipped (idempotent behavior - no unnecessary deployments) + +**LoadBalancer Resources** (LoadBalancerSlsResource, LiveLoadBalancer): +- Automatically skipped (don't deploy the mothership as a child) + +### Configuration + +#### Environment Variables + +The mothership uses: +- `RUNPOD_ENDPOINT_ID`: Mothership's endpoint ID (required for URL construction) +- `FLASH_IS_MOTHERSHIP`: Set to `"true"` to trigger auto-provisioning +- `RUNPOD_API_KEY`: Used for State Manager API authentication + +#### State Manager API + +The mothership persists manifest state via HTTP API: + +**Endpoints**: +- `GET /api/v1/flash/manifests/{mothership_id}` - Fetch persisted manifest +- `PUT /api/v1/flash/manifests/{mothership_id}/resources/{resource_name}` - Update resource +- `DELETE /api/v1/flash/manifests/{mothership_id}/resources/{resource_name}` - Remove resource + +**Base URL**: `https://api.runpod.io` (default, configurable) + +**Authentication**: Bearer token using RUNPOD_API_KEY + +### /manifest Endpoint + +The mothership serves a `/manifest` endpoint for service discovery: + +**Endpoint**: `GET /manifest` + +**Response**: +```json +{ + "manifest": { + "gpu_worker": "https://gpu-worker.api.runpod.ai", + "cpu_worker": "https://cpu-worker.api.runpod.ai" + } +} +``` + +**Behavior**: +- Queries 
ResourceManager for all deployed resources +- Returns partial results during provisioning (gradual population) +- Returns empty manifest if no resources deployed yet +- Graceful error handling - returns manifest and error field on failure + +### Idempotency + +Auto-provisioning is idempotent - running the mothership multiple times: + +**First Boot**: +- All resources in local manifest are NEW +- All deployed to State Manager +- Directory populated + +**Second Boot (unchanged manifest)**: +- All resources have matching config hashes +- All UNCHANGED - none deployed again +- Directory reused + +**Third Boot (with changes)**: +- Changed resources updated +- New resources deployed +- Removed resources undeployed +- Unchanged resources skipped +- Efficient - only changes applied + +This ensures: +- No duplicate resource deployments +- Automatic cleanup of removed resources +- Zero downtime provisioning +- Efficient use of cloud resources + +### Example Workflow + +```python +# main.py - Mothership application +from tetra_rp import LoadBalancerSlsResource, remote +from fastapi import FastAPI + +# Create mothership +mothership = LoadBalancerSlsResource( + name="mothership", + imageName="my-mothership:latest" +) + +# Deploy mothership (auto-provisioning triggered) +# await mothership.deploy() + +# FastAPI app +app = FastAPI() + +@app.get("/ping") +def ping(): + return {"status": "healthy"} + +@app.get("/manifest") +async def get_manifest(): + """Auto-generated endpoint via lifespan hook""" + # Returns directory of deployed children + return {"manifest": {...}} +``` + +### Monitoring Provisioning + +Check mothership logs for provisioning activity: + +``` +Mothership detected, initiating auto-provisioning +Mothership URL: https://mothership-123.api.runpod.ai +Reconciliation complete: 2 new, 1 changed, 1 removed, 3 unchanged +Deployed new resource: gpu_worker +Updated resource: cpu_worker +Deleted removed resource: old_worker +Provisioning complete +``` + +### Error Handling 
+ +Provisioning errors don't block mothership startup: + +``` +Failed to deploy gpu_worker: RuntimeError: GPU allocation failed +(State Manager updated with status: failed) + +Failed to update cpu_worker: ConnectionError: State Manager unavailable +(Continues provisioning other resources) +``` + +The mothership continues serving traffic even if some child deployments fail. + +### Architecture + +```mermaid +graph TD + A["Mothership Boot"] --> B["Lifespan Hook"] + B --> C["Check FLASH_IS_MOTHERSHIP"] + C -->|"true"| D["Spawn Background Task"] + C -->|"false"| Z["Skip provisioning"] + + D --> E["Load Local Manifest"] + E --> F["Query State Manager"] + F --> G["Reconcile Manifests"] + + G --> H["New Resources"] + G --> I["Changed Resources"] + G --> J["Removed Resources"] + G --> K["Unchanged Resources"] + + H --> L["ResourceManager.deploy()"] + I --> L + J --> M["ResourceManager.undeploy()"] + K --> N["Skip (idempotent)"] + + L --> O["Update State Manager"] + M --> P["Remove from State Manager"] + + O --> Q["Directory Updated"] + P --> Q + N --> Q + + Q --> R["/manifest Endpoint Returns
Updated Directory"] + + style A fill:#1976d2,stroke:#0d47a1,stroke-width:3px,color:#fff + style B fill:#0d7f1f,stroke:#0d4f1f,stroke-width:3px,color:#fff + style D fill:#f57c00,stroke:#e65100,stroke-width:3px,color:#fff + style G fill:#d32f2f,stroke:#b71c1c,stroke-width:3px,color:#fff + style Q fill:#1976d2,stroke:#0d47a1,stroke-width:3px,color:#fff + style R fill:#1976d2,stroke:#0d47a1,stroke-width:3px,color:#fff +``` + +### Implementation Files + +- **StateManagerClient**: `src/tetra_rp/runtime/state_manager_client.py` +- **MothershipProvisioner**: `src/tetra_rp/runtime/mothership_provisioner.py` +- **LB Handler Generator**: `src/tetra_rp/cli/commands/build_utils/lb_handler_generator.py` +- **LoadBalancerSlsResource**: `src/tetra_rp/core/resources/load_balancer_sls_resource.py` + ## Conclusion Cross-endpoint routing provides: @@ -939,5 +1199,6 @@ Cross-endpoint routing provides: - **Resilience**: Graceful fallback to local execution if directory unavailable - **Simplicity**: No changes to function code or signatures - **Debuggability**: Clear error messages and logging for troubleshooting +- **Automation**: Mothership auto-provisioning eliminates manual resource deployment -The architecture prioritizes clarity and maintainability while enabling distributed serverless applications. +The architecture prioritizes clarity and maintainability while enabling distributed serverless applications with automated deployment orchestration. 
diff --git a/pyproject.toml b/pyproject.toml index ee0b3f4e..a81a0656 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,7 @@ dependencies = [ "typer>=0.12.0", "questionary>=2.0.0", "pathspec>=0.11.0", + "tomli>=2.0.0; python_version < '3.11'", ] [dependency-groups] diff --git a/src/tetra_rp/cli/commands/build.py b/src/tetra_rp/cli/commands/build.py index b9b4179d..01d30348 100644 --- a/src/tetra_rp/cli/commands/build.py +++ b/src/tetra_rp/cli/commands/build.py @@ -1,6 +1,7 @@ """Flash build command - Package Flash applications for deployment.""" import ast +import importlib.util import json import logging import re @@ -9,6 +10,7 @@ import sys import tarfile from pathlib import Path +from typing import Optional import typer from rich.console import Console @@ -16,6 +18,11 @@ from rich.progress import Progress, SpinnerColumn, TextColumn from rich.table import Table +try: + import tomllib # Python 3.11+ +except ImportError: + import tomli as tomllib # Python 3.9-3.10 + from ..utils.ignore import get_file_tree, load_ignore_patterns from .build_utils.handler_generator import HandlerGenerator from .build_utils.lb_handler_generator import LBHandlerGenerator @@ -52,6 +59,132 @@ PIP_MODULE = "pip" +def _find_local_tetra_rp() -> Optional[Path]: + """Find local tetra_rp source directory if available. + + Returns: + Path to tetra_rp package directory, or None if not found or installed from PyPI + """ + try: + spec = importlib.util.find_spec("tetra_rp") + + if not spec or not spec.origin: + return None + + # Get package directory (spec.origin is __init__.py path) + pkg_dir = Path(spec.origin).parent + + # Skip if installed in site-packages (PyPI install) + if "site-packages" in str(pkg_dir): + return None + + # Must be development install + return pkg_dir + + except Exception: + return None + + +def _bundle_local_tetra_rp(build_dir: Path) -> bool: + """Copy local tetra_rp source into build directory. 
+ + Args: + build_dir: Target build directory + + Returns: + True if bundled successfully, False otherwise + """ + tetra_pkg = _find_local_tetra_rp() + + if not tetra_pkg: + console.print( + "[yellow]⚠ Local tetra_rp not found or using PyPI install[/yellow]" + ) + return False + + # Copy tetra_rp to build + dest = build_dir / "tetra_rp" + if dest.exists(): + shutil.rmtree(dest) + + shutil.copytree( + tetra_pkg, + dest, + ignore=shutil.ignore_patterns("__pycache__", "*.pyc", ".pytest_cache"), + ) + + console.print(f"[cyan]✓ Bundled local tetra_rp from {tetra_pkg}[/cyan]") + return True + + +def _extract_tetra_rp_dependencies(tetra_pkg_dir: Path) -> list[str]: + """Extract runtime dependencies from tetra_rp's pyproject.toml. + + When bundling local tetra_rp source, we need to also install its dependencies + so they're available in the build environment. + + Args: + tetra_pkg_dir: Path to tetra_rp package directory (src/tetra_rp) + + Returns: + List of dependency strings, empty list if parsing fails + """ + try: + # Navigate from tetra_rp package to project root + # tetra_pkg_dir is src/tetra_rp, need to go up 2 levels to reach project root + project_root = tetra_pkg_dir.parent.parent + pyproject_path = project_root / "pyproject.toml" + + if not pyproject_path.exists(): + console.print( + "[yellow]⚠ tetra_rp pyproject.toml not found, " + "dependencies may be missing[/yellow]" + ) + return [] + + # Parse TOML + with open(pyproject_path, "rb") as f: + data = tomllib.load(f) + + # Extract dependencies from [project.dependencies] + dependencies = data.get("project", {}).get("dependencies", []) + + if dependencies: + console.print( + f"[dim]Found {len(dependencies)} tetra_rp dependencies to install[/dim]" + ) + + return dependencies + + except Exception as e: + console.print(f"[yellow]⚠ Failed to parse tetra_rp dependencies: {e}[/yellow]") + return [] + + +def _remove_tetra_from_requirements(build_dir: Path) -> None: + """Remove tetra_rp from requirements.txt and clean up 
dist-info since we bundled source.""" + req_file = build_dir / "requirements.txt" + + if not req_file.exists(): + return + + lines = req_file.read_text().splitlines() + filtered = [ + line + for line in lines + if not line.strip().startswith("tetra_rp") + and not line.strip().startswith("tetra-rp") + ] + + req_file.write_text("\n".join(filtered) + "\n") + + # Remove tetra_rp dist-info directory to avoid conflicts with bundled source + # dist-info is created by pip install and can confuse Python's import system + for dist_info in build_dir.glob("tetra_rp-*.dist-info"): + if dist_info.is_dir(): + shutil.rmtree(dist_info) + + def build_command( no_deps: bool = typer.Option( False, "--no-deps", help="Skip transitive dependencies during pip install" @@ -67,6 +200,11 @@ def build_command( "--exclude", help="Comma-separated packages to exclude (e.g., 'torch,torchvision')", ), + use_local_tetra: bool = typer.Option( + False, + "--use-local-tetra", + help="Bundle local tetra_rp source instead of PyPI version (for development/testing)", + ), ): """ Build Flash application for deployment. 
@@ -90,6 +228,9 @@ def build_command( console.print("Run [bold]flash init[/bold] to create a Flash project") raise typer.Exit(1) + # Create build directory first to ensure clean state before collecting files + build_dir = create_build_directory(project_dir, app_name) + # Parse exclusions excluded_packages = [] if exclude: @@ -121,9 +262,8 @@ def build_command( ) progress.stop_task(collect_task) - # Create build directory + # Note: build directory already created before progress tracking build_task = progress.add_task("Creating build directory...") - build_dir = create_build_directory(project_dir, app_name) progress.update( build_task, description="[green]✓ Created .flash/.build/", @@ -156,15 +296,16 @@ def build_command( handler_paths = [] # Separate resources by type + # Use flag determined by isinstance() at scan time lb_resources = { name: data for name, data in manifest.get("resources", {}).items() - if data.get("resource_type") == "LoadBalancerSlsResource" + if data.get("is_load_balanced", False) } qb_resources = { name: data for name, data in manifest.get("resources", {}).items() - if data.get("resource_type") != "LoadBalancerSlsResource" + if not data.get("is_load_balanced", False) } # Generate LB handlers @@ -219,10 +360,21 @@ def build_command( logger.exception("Build failed") raise typer.Exit(1) + # Extract tetra_rp dependencies if bundling local version + tetra_deps = [] + if use_local_tetra: + tetra_pkg = _find_local_tetra_rp() + if tetra_pkg: + tetra_deps = _extract_tetra_rp_dependencies(tetra_pkg) + # Install dependencies deps_task = progress.add_task("Installing dependencies...") requirements = collect_requirements(project_dir, build_dir) + # Add tetra_rp dependencies if bundling local version + # This ensures all tetra_rp runtime dependencies are available in the build + requirements.extend(tetra_deps) + # Filter out excluded packages if excluded_packages: original_count = len(requirements) @@ -280,6 +432,22 @@ def build_command( 
progress.stop_task(deps_task) + # Bundle local tetra_rp if requested + if use_local_tetra: + tetra_task = progress.add_task("Bundling local tetra_rp...") + if _bundle_local_tetra_rp(build_dir): + _remove_tetra_from_requirements(build_dir) + progress.update( + tetra_task, + description="[green]✓ Bundled local tetra_rp", + ) + else: + progress.update( + tetra_task, + description="[yellow]⚠ Using PyPI tetra_rp", + ) + progress.stop_task(tetra_task) + # Clean up Python bytecode before archiving cleanup_python_bytecode(build_dir) diff --git a/src/tetra_rp/cli/commands/build_utils/handler_generator.py b/src/tetra_rp/cli/commands/build_utils/handler_generator.py index ef6ec318..8dace74f 100644 --- a/src/tetra_rp/cli/commands/build_utils/handler_generator.py +++ b/src/tetra_rp/cli/commands/build_utils/handler_generator.py @@ -57,12 +57,13 @@ def generate_handlers(self) -> List[Path]: for resource_name, resource_data in resources.items(): # Skip load-balanced resources (handled by LBHandlerGenerator) - resource_type = ( - resource_data.resource_type - if hasattr(resource_data, "resource_type") - else resource_data.get("resource_type") + # Use flag determined by isinstance() at scan time + is_load_balanced = ( + resource_data.is_load_balanced + if hasattr(resource_data, "is_load_balanced") + else resource_data.get("is_load_balanced", False) ) - if resource_type == "LoadBalancerSlsResource": + if is_load_balanced: continue handler_path = self._generate_handler(resource_name, resource_data) diff --git a/src/tetra_rp/cli/commands/build_utils/lb_handler_generator.py b/src/tetra_rp/cli/commands/build_utils/lb_handler_generator.py index f7c679c3..b639773e 100644 --- a/src/tetra_rp/cli/commands/build_utils/lb_handler_generator.py +++ b/src/tetra_rp/cli/commands/build_utils/lb_handler_generator.py @@ -21,9 +21,17 @@ - Real-time communication patterns """ +import asyncio +import logging +from contextlib import asynccontextmanager +from pathlib import Path +from typing import Optional 
+ from fastapi import FastAPI, Request from tetra_rp.runtime.lb_handler import create_lb_handler +logger = logging.getLogger(__name__) + # Import all functions/classes that belong to this resource {imports} @@ -32,11 +40,70 @@ {registry} }} -# Create FastAPI app with routes +# Module-level state for /manifest endpoint +_state_client: Optional[StateManagerClient] = None + + +# Lifespan context manager for startup/shutdown +@asynccontextmanager +async def lifespan(app: FastAPI): + """Handle application startup and shutdown.""" + # Startup + logger.info("Starting {resource_name} endpoint") + + # Check if this is the mothership and initiate provisioning + try: + from tetra_rp.runtime.mothership_provisioner import ( + is_mothership, + provision_children, + get_mothership_url, + ) + from tetra_rp.runtime.state_manager_client import StateManagerClient + + if is_mothership(): + logger.info("=" * 60) + logger.info("Mothership detected - Starting auto-provisioning") + logger.info("Test phase: Deploying child endpoints with 'tmp-' prefix") + logger.info("=" * 60) + try: + mothership_url = get_mothership_url() + logger.info(f"Mothership URL: {{mothership_url}}") + + # Initialize State Manager client and store in module-level state + state_client = StateManagerClient() + global _state_client + _state_client = state_client + + # Spawn background provisioning task (non-blocking) + manifest_path = Path(__file__).parent / "flash_manifest.json" + task = asyncio.create_task( + provision_children(manifest_path, mothership_url, state_client) + ) + # Add error callback to catch and log background task exceptions + task.add_done_callback( + lambda t: logger.error(f"Background provisioning failed: {{t.exception()}}") + if t.exception() + else None + ) + + except Exception as e: + logger.error(f"Failed to start mothership provisioning: {{e}}") + # Don't fail startup - continue serving traffic + + except ImportError: + logger.debug("Mothership provisioning modules not available") + + yield 
+ + # Shutdown + logger.info("Shutting down {resource_name} endpoint") + + +# Create FastAPI app with routes and lifespan # Note: include_execute={include_execute} for this endpoint type # - LiveLoadBalancer (local): include_execute=True for /execute endpoint # - LoadBalancerSlsResource (deployed): include_execute=False (security) -app = create_lb_handler(ROUTE_REGISTRY, include_execute={include_execute}) +app = create_lb_handler(ROUTE_REGISTRY, include_execute={include_execute}, lifespan=lifespan) # Health check endpoint (required for RunPod load-balancer endpoints) @@ -50,6 +117,58 @@ def ping(): return {{"status": "healthy"}} +# Manifest endpoint for service discovery +@app.get("/manifest") +async def manifest(): + """Return complete authoritative manifest for service discovery. + + Fetches the full manifest from State Manager, allowing child endpoints + to synchronize their configuration. + + Returns: + dict: Complete manifest with version, generated_at, project_name, + function_registry, resources, and routes + """ + try: + import os + from tetra_rp.runtime.mothership_provisioner import is_mothership + + # Only mothership serves manifest + if not is_mothership(): + return {{"error": "Only mothership serves manifest"}}, 403 + + # Check state client initialized + global _state_client + if _state_client is None: + return {{"error": "State Manager not initialized"}}, 500 + + # Get mothership ID + mothership_id = os.getenv("RUNPOD_ENDPOINT_ID") + if not mothership_id: + return {{"error": "RUNPOD_ENDPOINT_ID not set"}}, 500 + + # Fetch persisted manifest from State Manager (single source of truth) + persisted_manifest = await _state_client.get_persisted_manifest(mothership_id) + + # First boot: no manifest yet, return minimal structure + if persisted_manifest is None: + return {{ + "version": "1.0", + "generated_at": "", + "project_name": "", + "function_registry": {{}}, + "resources": {{}}, + "routes": {{}} + }} + + # Return complete manifest + return 
persisted_manifest + + except Exception as e: + logger.error(f"Failed to get manifest: {{e}}") + return {{"error": str(e)}}, 500 + + if __name__ == "__main__": import uvicorn # Local development server for testing @@ -77,12 +196,13 @@ def generate_handlers(self) -> List[Path]: for resource_name, resource_data in resources.items(): # Generate for both LiveLoadBalancer (local dev) and LoadBalancerSlsResource (deployed) - resource_type = ( - resource_data.resource_type - if hasattr(resource_data, "resource_type") - else resource_data.get("resource_type") + # Use flag determined by isinstance() at scan time + is_load_balanced = ( + resource_data.is_load_balanced + if hasattr(resource_data, "is_load_balanced") + else resource_data.get("is_load_balanced", False) ) - if resource_type not in ["LoadBalancerSlsResource", "LiveLoadBalancer"]: + if not is_load_balanced: continue handler_path = self._generate_handler(resource_name, resource_data) @@ -104,12 +224,12 @@ def _generate_handler(self, resource_name: str, resource_data: Any) -> Path: # Determine if /execute endpoint should be included # LiveLoadBalancer (local dev) includes /execute, deployed LoadBalancerSlsResource does not - resource_type = ( - resource_data.resource_type - if hasattr(resource_data, "resource_type") - else resource_data.get("resource_type", "LoadBalancerSlsResource") + # Use flag determined by isinstance() at scan time + include_execute = ( + resource_data.is_live_resource + if hasattr(resource_data, "is_live_resource") + else resource_data.get("is_live_resource", False) ) - include_execute = resource_type == "LiveLoadBalancer" # Get functions from resource (handle both dict and ResourceConfig) functions = ( @@ -143,13 +263,19 @@ def _generate_handler(self, resource_name: str, resource_data: Any) -> Path: def _generate_imports(self, functions: List[Any]) -> str: """Generate import statements for functions. 
+ Uses importlib to handle module paths with any characters, + including numeric prefixes that aren't valid Python identifiers. + Args: functions: List of function metadata (dicts or FunctionMetadata objects) Returns: Import statements as string """ - imports = [] + if not functions: + return "# No functions to import" + + imports = ["import importlib"] for func in functions: # Handle both dict and FunctionMetadata @@ -157,9 +283,10 @@ def _generate_imports(self, functions: List[Any]) -> str: name = func.name if hasattr(func, "name") else func.get("name") if module and name: - imports.append(f"from {module} import {name}") + # Use importlib to handle module names with invalid identifiers + imports.append(f"{name} = importlib.import_module('{module}').{name}") - return "\n".join(imports) if imports else "# No functions to import" + return "\n".join(imports) def _generate_route_registry(self, functions: List[Any]) -> str: """Generate route registry for FastAPI app. diff --git a/src/tetra_rp/cli/commands/build_utils/manifest.py b/src/tetra_rp/cli/commands/build_utils/manifest.py index bd0261cd..02edc3b7 100644 --- a/src/tetra_rp/cli/commands/build_utils/manifest.py +++ b/src/tetra_rp/cli/commands/build_utils/manifest.py @@ -21,6 +21,9 @@ class ManifestFunction: is_class: bool http_method: Optional[str] = None # HTTP method for LB endpoints (GET, POST, etc.) 
http_path: Optional[str] = None # HTTP path for LB endpoints (/api/process) + is_load_balanced: bool = False # Determined by isinstance() at scan time + is_live_resource: bool = False # LiveLoadBalancer vs LoadBalancerSlsResource + config_variable: Optional[str] = None # Variable name like "gpu_config" @dataclass @@ -30,6 +33,9 @@ class ManifestResource: resource_type: str handler_file: str functions: List[ManifestFunction] + is_load_balanced: bool = False # Determined by isinstance() at scan time + is_live_resource: bool = False # LiveLoadBalancer vs LoadBalancerSlsResource + config_variable: Optional[str] = None # Variable name for test-mothership class ManifestBuilder: @@ -66,12 +72,12 @@ def build(self) -> Dict[str, Any]: functions[0].resource_type if functions else "LiveServerless" ) + # Extract flags from first function (determined by isinstance() at scan time) + is_load_balanced = functions[0].is_load_balanced if functions else False + is_live_resource = functions[0].is_live_resource if functions else False + # Validate and collect routing for LB endpoints resource_routes = {} - is_load_balanced = resource_type in [ - "LoadBalancerSlsResource", - "LiveLoadBalancer", - ] if is_load_balanced: for f in functions: if not f.http_method or not f.http_path: @@ -98,12 +104,18 @@ def build(self) -> Dict[str, Any]: f"Reserved paths: {', '.join(RESERVED_PATHS)}" ) + # Extract config_variable from first function (all functions in same resource share same config) + config_variable = functions[0].config_variable if functions else None + functions_list = [ { "name": f.function_name, "module": f.module_path, "is_async": f.is_async, "is_class": f.is_class, + "is_load_balanced": f.is_load_balanced, + "is_live_resource": f.is_live_resource, + "config_variable": f.config_variable, **( {"http_method": f.http_method, "http_path": f.http_path} if is_load_balanced @@ -117,6 +129,9 @@ def build(self) -> Dict[str, Any]: "resource_type": resource_type, "handler_file": handler_file, 
"functions": functions_list, + "is_load_balanced": is_load_balanced, + "is_live_resource": is_live_resource, + "config_variable": config_variable, } # Store routes for LB endpoints diff --git a/src/tetra_rp/cli/commands/build_utils/scanner.py b/src/tetra_rp/cli/commands/build_utils/scanner.py index 92f80fc2..1ea655fd 100644 --- a/src/tetra_rp/cli/commands/build_utils/scanner.py +++ b/src/tetra_rp/cli/commands/build_utils/scanner.py @@ -24,6 +24,11 @@ class RemoteFunctionMetadata: file_path: Path http_method: Optional[str] = None # HTTP method for LB endpoints: GET, POST, etc. http_path: Optional[str] = None # HTTP path for LB endpoints: /api/process + is_load_balanced: bool = False # LoadBalancerSlsResource or LiveLoadBalancer + is_live_resource: bool = ( + False # LiveLoadBalancer (vs deployed LoadBalancerSlsResource) + ) + config_variable: Optional[str] = None # Variable name like "gpu_config" class RemoteDecoratorScanner: @@ -34,6 +39,8 @@ def __init__(self, project_dir: Path): self.py_files: List[Path] = [] self.resource_configs: Dict[str, str] = {} # name -> name self.resource_types: Dict[str, str] = {} # name -> type + self.resource_flags: Dict[str, Dict[str, bool]] = {} # name -> {flag: bool} + self.resource_variables: Dict[str, str] = {} # name -> variable_name def discover_remote_functions(self) -> List[RemoteFunctionMetadata]: """Discover all @remote decorated functions and classes.""" @@ -83,7 +90,11 @@ def discover_remote_functions(self) -> List[RemoteFunctionMetadata]: return functions def _extract_resource_configs(self, tree: ast.AST, py_file: Path) -> None: - """Extract resource config variable assignments.""" + """Extract resource config variable assignments and determine type flags. + + This method extracts resource configurations and determines is_load_balanced + and is_live_resource flags using string-based type matching. 
+ """ module_path = self._get_module_path(py_file) for node in ast.walk(tree): @@ -94,7 +105,7 @@ def _extract_resource_configs(self, tree: ast.AST, py_file: Path) -> None: variable_name = target.id config_type = self._get_call_type(node.value) - # Accept any class that looks like a resource config (ServerlessResource) + # Accept any class that looks like a resource config (DeployableResource) if config_type and self._is_resource_config_type(config_type): # Extract the resource's name parameter (the actual identifier) # If extraction fails, fall back to variable name @@ -106,10 +117,34 @@ def _extract_resource_configs(self, tree: ast.AST, py_file: Path) -> None: self.resource_configs[resource_name] = resource_name self.resource_types[resource_name] = config_type + # Store variable name for test-mothership config discovery + self.resource_variables[resource_name] = variable_name + # Also store variable name mapping for local lookups in same module var_key = f"{module_path}:{variable_name}" self.resource_configs[var_key] = resource_name self.resource_types[var_key] = config_type + self.resource_variables[var_key] = variable_name + + # Determine boolean flags using string-based type checking + # This is determined by isinstance() at scan time in production, + # but we use string matching for reliability + is_load_balanced = config_type in [ + "LoadBalancerSlsResource", + "LiveLoadBalancer", + ] + is_live_resource = config_type == "LiveLoadBalancer" + + # Store flags for this resource + self.resource_flags[resource_name] = { + "is_load_balanced": is_load_balanced, + "is_live_resource": is_live_resource, + } + # Also store for variable key + self.resource_flags[var_key] = { + "is_load_balanced": is_load_balanced, + "is_live_resource": is_live_resource, + } def _extract_remote_functions( self, tree: ast.AST, py_file: Path @@ -141,6 +176,12 @@ def _extract_remote_functions( remote_decorator ) + # Get flags for this resource + flags = self.resource_flags.get( + 
resource_config_name, + {"is_load_balanced": False, "is_live_resource": False}, + ) + metadata = RemoteFunctionMetadata( function_name=node.name, module_path=module_path, @@ -151,6 +192,11 @@ def _extract_remote_functions( file_path=py_file, http_method=http_method, http_path=http_path, + is_load_balanced=flags["is_load_balanced"], + is_live_resource=flags["is_live_resource"], + config_variable=self.resource_variables.get( + resource_config_name + ), ) functions.append(metadata) diff --git a/src/tetra_rp/cli/commands/run.py b/src/tetra_rp/cli/commands/run.py index 8508e08b..ba79d5fa 100644 --- a/src/tetra_rp/cli/commands/run.py +++ b/src/tetra_rp/cli/commands/run.py @@ -206,6 +206,16 @@ def _discover_resources(entry_point: str): try: discovery = ResourceDiscovery(entry_point, max_depth=2) resources = discovery.discover() + + # Debug: Log what was discovered + if resources: + console.print(f"\n[dim]Discovered {len(resources)} resource(s):[/dim]") + for res in resources: + res_name = getattr(res, "name", "Unknown") + res_type = res.__class__.__name__ + console.print(f" [dim]• {res_name} ({res_type})[/dim]") + console.print() + return resources except Exception as e: console.print(f"[yellow]Warning:[/yellow] Resource discovery failed: {e}") diff --git a/src/tetra_rp/cli/commands/test_mothership.py b/src/tetra_rp/cli/commands/test_mothership.py new file mode 100644 index 00000000..ea239e78 --- /dev/null +++ b/src/tetra_rp/cli/commands/test_mothership.py @@ -0,0 +1,457 @@ +"""Flash test-mothership command - Test mothership boot locally with Docker.""" + +import logging +import shutil +import subprocess +import sys +import time +from pathlib import Path +from typing import Optional + +import typer +from rich.console import Console +from rich.panel import Panel + +logger = logging.getLogger(__name__) +console = Console() + + +def _clear_resource_cache() -> None: + """Clear ResourceManager cache for clean test environment. 
+ + Test-mothership deploys temporary endpoints that should not persist + between test runs. Clearing the cache prevents: + - Stale resources from previous tests being redeployed + - Name conflicts between old and new test resources + - Confusion from endpoints that no longer exist in the codebase + """ + cache_file = Path.home() / ".runpod" / "resources.pkl" + if cache_file.exists(): + try: + cache_file.unlink() + console.print( + "[dim]Cleared resource cache for clean test environment[/dim]" + ) + logger.debug(f"Removed cache file: {cache_file}") + except Exception as e: + console.print(f"[yellow]Warning: Could not clear cache: {e}[/yellow]") + logger.warning(f"Failed to remove cache file {cache_file}: {e}") + + +def test_mothership_command( + image: str = typer.Option( + "runpod/tetra-rp-lb-cpu:local", + "--image", + help="Docker image to use for testing", + ), + port: int = typer.Option(8000, "--port", help="Local port to expose"), + endpoint_id: Optional[str] = typer.Option( + None, "--endpoint-id", help="RunPod endpoint ID (auto-generated if omitted)" + ), + build_dir: str = typer.Option( + ".flash/.build", "--build-dir", help="Path to build directory" + ), + no_build: bool = typer.Option( + False, "--no-build", help="Skip running flash build first" + ), +): + """ + Test mothership boot locally with Docker. + + Runs the application in a Docker container with mothership provisioning enabled. + This simulates the mothership deployment process, including auto-provisioning of + child resources to RunPod. On shutdown (Ctrl+C or docker stop), automatically + cleans up all deployed endpoints. 
+ + Examples: + flash test-mothership # Default setup + flash test-mothership --port 9000 # Custom port + flash test-mothership --image custom:latest # Custom Docker image + flash test-mothership --no-build # Skip flash build step + """ + try: + # Verify prerequisites + _verify_prerequisites() + + # Clear resource cache to prevent stale entries in test mode + _clear_resource_cache() + + # Build if needed + if not no_build: + _run_flash_build() + + # Generate endpoint ID if not provided + if not endpoint_id: + endpoint_id = f"test-mothership-{int(time.time())}" + + # Create entrypoint script for cleanup on shutdown + _create_entrypoint_script(build_dir) + + # Display configuration + _display_test_objectives() + _display_config(build_dir, image, port, endpoint_id) + + # Build Docker command + docker_cmd = _build_docker_command(image, port, endpoint_id, build_dir) + + # Run Docker container + _run_docker_container(docker_cmd, port) + + except typer.Exit: + raise + except Exception as e: + console.print(f"[red]Error:[/red] {e}") + logger.exception("Unexpected error in test_mothership_command") + raise typer.Exit(1) + + +def _verify_prerequisites() -> None: + """Verify that Docker and RUNPOD_API_KEY are available.""" + # Check Docker + result = shutil.which("docker") + if not result: + console.print("[red]Error:[/red] Docker is not installed or not in PATH") + console.print( + "Install Docker from: https://www.docker.com/products/docker-desktop" + ) + raise typer.Exit(1) + + # Check Docker daemon + try: + subprocess.run( + ["docker", "ps"], + capture_output=True, + check=True, + timeout=5, + ) + except ( + subprocess.CalledProcessError, + subprocess.TimeoutExpired, + FileNotFoundError, + ): + console.print("[red]Error:[/red] Docker daemon is not running") + console.print("Start Docker and try again") + raise typer.Exit(1) + + # Check RUNPOD_API_KEY + import os + + if not os.getenv("RUNPOD_API_KEY"): + console.print("[red]Error:[/red] RUNPOD_API_KEY environment variable 
not set") + console.print("Set it with: export RUNPOD_API_KEY=your-api-key") + raise typer.Exit(1) + + +def _run_flash_build() -> None: + """Run flash build command.""" + console.print("[cyan]Running flash build...[/cyan]") + result = subprocess.run( + ["flash", "build", "--keep-build", "--use-local-tetra"], + capture_output=False, + ) + if result.returncode != 0: + console.print("[red]Error:[/red] flash build failed") + raise typer.Exit(1) + + +def _get_manifest_provisioning_code() -> str: + """Generate Python code to provision resources from flash_manifest.json. + + Uses the manifest as a guide to discover which modules contain resource configs. + Imports the actual resource configs from source (endpoint files) to get full + configuration (workers, GPUs, etc.). This ensures test-mothership provisions + exactly what was built, without discovering skeleton templates. + + Returns: + Python code as a string to be executed + """ + return """ +import asyncio +import importlib +import json +import logging +import os +import sys +from pathlib import Path +from tetra_rp.core.deployment import DeploymentOrchestrator + +logger = logging.getLogger(__name__) + +# Configure logging to match the rest of the system +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s | %(levelname)-5s | %(message)s', + datefmt='%Y-%m-%d %H:%M:%S' +) + +async def provision_from_manifest(): + manifest_path = Path("flash_manifest.json") + if not manifest_path.exists(): + print("[dim]No flash_manifest.json found, skipping manifest-based provisioning[/dim]") + return + + try: + with open(manifest_path) as f: + manifest = json.load(f) + except Exception as e: + logger.error(f"Error loading manifest: {e}") + return + + # Set test-mothership mode for resource naming + os.environ["FLASH_IS_TEST_MOTHERSHIP"] = "true" + + resources = [] + for resource_name, resource_data in manifest.get("resources", {}).items(): + try: + # Get list of modules that contain this resource's functions + functions 
= resource_data.get("functions", []) + if not functions: + logger.warning(f"No functions found for resource {resource_name}") + continue + + # Import the first function's module to get access to the config + first_func = functions[0] + module_name = first_func.get("module") + if not module_name: + logger.warning(f"No module found for resource {resource_name}") + continue + + # Import the module and look for resource config variable + try: + module = importlib.import_module(module_name) + + config = None + + # Try config_variable from manifest first (most reliable) + config_variable = resource_data.get("config_variable") + if config_variable and hasattr(module, config_variable): + config = getattr(module, config_variable) + logger.info(f"Loaded resource config from {module_name}: {config.name} (variable: {config_variable})") + else: + # Fallback to old search logic for backward compatibility + config_names = [ + "gpu_config", "cpu_config", + "resource_config", "config", + f"{resource_name.lower()}_config", + ] + + for config_name in config_names: + if hasattr(module, config_name): + config = getattr(module, config_name) + break + + if config: + logger.info(f"Loaded resource config from {module_name}: {config.name}") + else: + logger.warning(f"No config variable found in {module_name} for {resource_name}") + + if config: + # Apply test-mothership naming convention + if not resource_name.startswith("tmp-"): + config.name = f"tmp-{resource_name}" + else: + config.name = resource_name + + resources.append(config) + + except Exception as e: + logger.warning(f"Failed to import resource config from {module_name}: {e}") + + except Exception as e: + logger.error(f"Failed to process resource {resource_name}: {e}") + + if resources: + try: + logger.info(f"Provisioning {len(resources)} resource(s)...") + orchestrator = DeploymentOrchestrator() + await orchestrator.deploy_all(resources, show_progress=True) + except Exception as e: + logger.warning(f"Provisioning error: {e}") + 
def _create_entrypoint_script(build_dir: str) -> None:
    """Create entrypoint.sh script for Docker container.

    This script handles signal trapping and cleanup on shutdown.
    It runs manifest-based provisioning then flash run (without --auto-provision
    to avoid duplicate discovery from bundled dependencies).

    Two files are written into the build directory:
    - provision_from_manifest.py: the provisioning program
    - entrypoint.sh: the container entrypoint (made executable)

    Args:
        build_dir: Path to the build directory that is mounted into the
            container. If it does not exist, a warning is printed and
            nothing is written.
    """
    build_path = Path(build_dir)

    # Ensure build directory exists
    if not build_path.exists():
        console.print(
            f"[yellow]Warning:[/yellow] Build directory {build_dir} does not exist"
        )
        return

    script_path = build_path / "entrypoint.sh"
    provisioning_script_path = build_path / "provision_from_manifest.py"

    # Both files are executed inside a Linux container, so they are written as
    # UTF-8 bytes with LF line endings regardless of the host platform.
    # Text-mode write_text() would apply the host's newline translation and
    # locale encoding; on Windows that yields "#!/bin/bash\r", which the
    # container cannot execute.
    provisioning_code = _get_manifest_provisioning_code()
    provisioning_script_path.write_bytes(provisioning_code.encode("utf-8"))

    script_content = """#!/bin/bash
set -e

# Ensure bundled dependencies are available to Python
# /workspace contains all the pip-installed packages (.so files, pure Python modules, etc)
export PYTHONPATH="/workspace:${PYTHONPATH}"

# Signal test-mothership provisioning context for resource naming
export FLASH_IS_TEST_MOTHERSHIP="true"

cleanup() {
    echo ""
    echo "=========================================="
    echo "Shutting down test-mothership..."
    echo "Cleaning up all temporary endpoints..."
    echo "=========================================="
    python -m tetra_rp.cli.main undeploy --all --force || true
    echo "Cleanup complete"
    exit 0
}

trap cleanup SIGTERM SIGINT

echo "=========================================="
echo "Starting mothership test environment"
echo "Phase 1: Mothership container startup"
echo "=========================================="

# Provision resources from manifest before starting server
# This uses the same method as production mothership, avoiding
# false discovery from bundled skeleton templates
python3 provision_from_manifest.py

# Start server without --auto-provision to avoid re-discovering resources
python -m tetra_rp.cli.main run --host 0.0.0.0 --port 8000 &
PID=$!

wait $PID
"""

    script_path.write_bytes(script_content.encode("utf-8"))
    script_path.chmod(0o755)
Use 'flash deploy' for production.[/dim]""" + + console.print( + Panel( + objectives_text, + title="Test-Mothership Overview", + border_style="cyan", + ) + ) + console.print() + + +def _display_config(build_dir: str, image: str, port: int, endpoint_id: str) -> None: + """Display test configuration.""" + config_text = f"""[bold]Build directory:[/bold] {build_dir} +[bold]Command:[/bold] flash run +[bold]Docker image:[/bold] {image} +[bold]Endpoint ID:[/bold] {endpoint_id} +[bold]Port:[/bold] http://localhost:{port}""" + + console.print(Panel(config_text, title="🚀 Starting mothership test container")) + + +def _build_docker_command( + image: str, port: int, endpoint_id: str, build_dir: str +) -> list: + """Build the docker run command.""" + import os + + build_path = Path(build_dir).resolve() + + cmd = [ + "docker", + "run", + "--platform", + "linux/amd64", + "--rm", + ] + + # Add interactive flags only if running in a TTY environment + if sys.stdin.isatty() and sys.stdout.isatty(): + cmd.extend(["-it"]) + + cmd.extend( + [ + "-e", + "FLASH_IS_MOTHERSHIP=true", + "-e", + "FLASH_IS_TEST_MOTHERSHIP=true", + "-e", + f"RUNPOD_ENDPOINT_ID={endpoint_id}", + "-e", + f"RUNPOD_API_KEY={os.getenv('RUNPOD_API_KEY')}", + "-e", + "FLASH_MANIFEST_PATH=/workspace/flash_manifest.json", + "-v", + f"{build_path}:/workspace", + "-p", + f"{port}:8000", + "--workdir", + "/workspace", + image, + "/workspace/entrypoint.sh", + ] + ) + + return cmd + + +def _run_docker_container(docker_cmd: list, port: int) -> None: + """Run the Docker container with helpful output.""" + console.print("[cyan]✅ Container started successfully[/cyan]\n") + console.print(f"[dim]Local: http://localhost:{port}[/dim]\n") + console.print("[dim]Verification commands:[/dim]") + console.print(f"[dim] Health: curl http://localhost:{port}/ping[/dim]") + console.print( + f"[dim] Manifest (child endpoints see this):" + f" curl http://localhost:{port}/manifest[/dim]\n" + ) + console.print("[bold]Test phases:[/bold]") + 
console.print(" [dim]1. Mothership startup and health check[/dim]") + console.print( + " [dim]2. Auto-provisioning child endpoints (prefixed with 'tmp-')[/dim]" + ) + console.print(" [dim]3. Manifest update with child endpoint URLs[/dim]") + console.print() + console.print("[dim]Watch container logs below for provisioning progress...[/dim]") + console.print("[dim]Press Ctrl+C to stop and cleanup all endpoints.\n[/dim]") + + try: + result = subprocess.run(docker_cmd, check=False, capture_output=False) + if result.returncode != 0: + console.print( + "\n[yellow]Container exited with an error.[/yellow] " + "Check the logs above for details. Common issues: missing RUNPOD_API_KEY, " + "port already in use, or Docker daemon not running." + ) + except KeyboardInterrupt: + console.print("\n[yellow]Container stopped[/yellow]") + except Exception as e: + console.print(f"[red]Error running container:[/red] {e}") + raise typer.Exit(1) diff --git a/src/tetra_rp/cli/commands/undeploy.py b/src/tetra_rp/cli/commands/undeploy.py index 7aaa349b..0cc0b165 100644 --- a/src/tetra_rp/cli/commands/undeploy.py +++ b/src/tetra_rp/cli/commands/undeploy.py @@ -225,6 +225,9 @@ def undeploy_command( "--cleanup-stale", help="Remove inactive endpoints from tracking (already deleted externally)", ), + force: bool = typer.Option( + False, "--force", "-f", help="Force operation without confirmation prompts" + ), ): """Undeploy (delete) RunPod serverless endpoints. 
@@ -239,6 +242,9 @@ def undeploy_command( # Undeploy all endpoints (with confirmation) flash undeploy --all + # Undeploy all endpoints without confirmation + flash undeploy --all --force + # Interactive selection flash undeploy --interactive @@ -271,11 +277,11 @@ def undeploy_command( # Handle different modes if interactive: - _interactive_undeploy(resources) + _interactive_undeploy(resources, skip_confirm=force) elif all: - _undeploy_all(resources) + _undeploy_all(resources, skip_confirm=force) elif name: - _undeploy_by_name(name, resources) + _undeploy_by_name(name, resources, skip_confirm=force) else: console.print( Panel( @@ -291,12 +297,13 @@ def undeploy_command( raise typer.Exit(0) -def _undeploy_by_name(name: str, resources: dict): +def _undeploy_by_name(name: str, resources: dict, skip_confirm: bool = False): """Undeploy endpoints matching the given name. Args: name: Name to search for resources: Dict of all resources + skip_confirm: Skip confirmation prompts """ # Find matching resources matches = [] @@ -333,17 +340,18 @@ def _undeploy_by_name(name: str, resources: dict): console.print("[red]🚨 This action cannot be undone![/red]\n") - try: - confirmed = questionary.confirm( - f"Are you sure you want to delete {len(matches)} endpoint(s)?" - ).ask() - - if not confirmed: - console.print("Undeploy cancelled") + if not skip_confirm: + try: + confirmed = questionary.confirm( + f"Are you sure you want to delete {len(matches)} endpoint(s)?" 
+ ).ask() + + if not confirmed: + console.print("Undeploy cancelled") + raise typer.Exit(0) + except KeyboardInterrupt: + console.print("\nUndeploy cancelled") raise typer.Exit(0) - except KeyboardInterrupt: - console.print("\nUndeploy cancelled") - raise typer.Exit(0) # Delete endpoints manager = _get_resource_manager() @@ -369,11 +377,12 @@ def _undeploy_by_name(name: str, resources: dict): console.print(f" • {result['message']}") -def _undeploy_all(resources: dict): +def _undeploy_all(resources: dict, skip_confirm: bool = False): """Undeploy all endpoints with confirmation. Args: resources: Dict of all resources + skip_confirm: Skip confirmation prompts """ # Show what will be deleted console.print( @@ -391,24 +400,25 @@ def _undeploy_all(resources: dict): console.print("\n[red]🚨 This action cannot be undone![/red]\n") - try: - confirmed = questionary.confirm( - f"Are you sure you want to delete ALL {len(resources)} endpoints?" - ).ask() + if not skip_confirm: + try: + confirmed = questionary.confirm( + f"Are you sure you want to delete ALL {len(resources)} endpoints?" 
+ ).ask() - if not confirmed: - console.print("Undeploy cancelled") - raise typer.Exit(0) + if not confirmed: + console.print("Undeploy cancelled") + raise typer.Exit(0) - # Double confirmation for --all - typed_confirm = questionary.text("Type 'DELETE ALL' to confirm:").ask() + # Double confirmation for --all + typed_confirm = questionary.text("Type 'DELETE ALL' to confirm:").ask() - if typed_confirm != "DELETE ALL": - console.print("Confirmation failed - text does not match") - raise typer.Exit(1) - except KeyboardInterrupt: - console.print("\nUndeploy cancelled") - raise typer.Exit(0) + if typed_confirm != "DELETE ALL": + console.print("Confirmation failed - text does not match") + raise typer.Exit(1) + except KeyboardInterrupt: + console.print("\nUndeploy cancelled") + raise typer.Exit(0) # Delete all endpoints manager = _get_resource_manager() @@ -436,11 +446,12 @@ def _undeploy_all(resources: dict): console.print(f" • {result['message']}") -def _interactive_undeploy(resources: dict): +def _interactive_undeploy(resources: dict, skip_confirm: bool = False): """Interactive checkbox selection for undeploying endpoints. Args: resources: Dict of all resources + skip_confirm: Skip confirmation prompts """ # Create choices for questionary choices = [] @@ -484,13 +495,14 @@ def _interactive_undeploy(resources: dict): console.print("\n[red]🚨 This action cannot be undone![/red]\n") - confirmed = questionary.confirm( - f"Are you sure you want to delete {len(selected)} endpoint(s)?" - ).ask() + if not skip_confirm: + confirmed = questionary.confirm( + f"Are you sure you want to delete {len(selected)} endpoint(s)?" 
+ ).ask() - if not confirmed: - console.print("Undeploy cancelled") - raise typer.Exit(0) + if not confirmed: + console.print("Undeploy cancelled") + raise typer.Exit(0) except KeyboardInterrupt: console.print("\nUndeploy cancelled") raise typer.Exit(0) diff --git a/src/tetra_rp/cli/main.py b/src/tetra_rp/cli/main.py index f1735aba..6ec076ec 100644 --- a/src/tetra_rp/cli/main.py +++ b/src/tetra_rp/cli/main.py @@ -9,6 +9,7 @@ init, run, build, + test_mothership, # resource, deploy, apps, @@ -38,6 +39,7 @@ def get_version() -> str: app.command("init")(init.init_command) app.command("run")(run.run_command) app.command("build")(build.build_command) +app.command("test-mothership")(test_mothership.test_mothership_command) # app.command("report")(resource.report_command) diff --git a/src/tetra_rp/core/discovery.py b/src/tetra_rp/core/discovery.py index 0a18b8e2..8ce4f3e5 100644 --- a/src/tetra_rp/core/discovery.py +++ b/src/tetra_rp/core/discovery.py @@ -58,10 +58,14 @@ def discover(self) -> List[DeployableResource]: else: log.warning(f"Failed to import {self.entry_point}") + log.info(f"[Discovery] After entry point: {len(resources)} resource(s)") + # Recursively scan imported modules (static imports) imported_resources = self._scan_imports(self.entry_point, depth=1) resources.extend(imported_resources) + log.info(f"[Discovery] After static imports: {len(resources)} resource(s)") + # Fallback: Scan project directory for Python files with @remote decorators # This handles dynamic imports (importlib.util) that AST parsing misses if not resources: @@ -70,6 +74,15 @@ def discover(self) -> List[DeployableResource]: ) directory_resources = self._scan_project_directory() resources.extend(directory_resources) + log.info( + f"[Discovery] After directory scan: {len(resources)} resource(s)" + ) + + log.info(f"[Discovery] Total: {len(resources)} resource(s) discovered") + for res in resources: + res_name = getattr(res, "name", "Unknown") + res_type = res.__class__.__name__ + 
log.info(f"[Discovery] • {res_name} ({res_type})") # Cache results self._cache[str(self.entry_point)] = resources @@ -363,6 +376,7 @@ def _scan_project_directory(self) -> List[DeployableResource]: "dist/", ".tox/", "node_modules/", + ".flash/", ] ): continue diff --git a/src/tetra_rp/core/resources/load_balancer_sls_resource.py b/src/tetra_rp/core/resources/load_balancer_sls_resource.py index a7d274dd..fa30963a 100644 --- a/src/tetra_rp/core/resources/load_balancer_sls_resource.py +++ b/src/tetra_rp/core/resources/load_balancer_sls_resource.py @@ -254,6 +254,11 @@ async def _do_deploy(self) -> "LoadBalancerSlsResource": return self try: + # Mark this endpoint as a mothership (triggers auto-provisioning on boot) + if self.env is None: + self.env = {} + self.env["FLASH_IS_MOTHERSHIP"] = "true" + # Call parent deploy (creates endpoint via RunPod API) log.info(f"Deploying LB endpoint {self.name}...") deployed = await super()._do_deploy() diff --git a/src/tetra_rp/core/resources/resource_manager.py b/src/tetra_rp/core/resources/resource_manager.py index ed3c3184..738482c5 100644 --- a/src/tetra_rp/core/resources/resource_manager.py +++ b/src/tetra_rp/core/resources/resource_manager.py @@ -252,10 +252,22 @@ async def get_or_deploy_resource( if not existing.is_deployed(): log.warning(f"{existing} is no longer valid, redeploying.") self._remove_resource(resource_key) - deployed_resource = await self._deploy_with_error_context(config) - log.info(f"URL: {deployed_resource.url}") - self._add_resource(resource_key, deployed_resource) - return deployed_resource + try: + deployed_resource = await self._deploy_with_error_context( + config + ) + log.info(f"URL: {deployed_resource.url}") + self._add_resource(resource_key, deployed_resource) + return deployed_resource + except Exception: + # Universal rule: If resource was created (has ID), track it for cleanup + if hasattr(config, "id") and config.id: + log.warning( + f"Deployment failed but resource '{config.name}' was created 
with ID {config.id}, " + f"caching for cleanup" + ) + self._add_resource(resource_key, config) + raise # Check for config drift stored_config_hash = self._resource_configs.get(resource_key, "") @@ -288,12 +300,22 @@ async def get_or_deploy_resource( "redeploying" ) await existing.undeploy() - deployed_resource = await self._deploy_with_error_context( - config - ) - log.info(f"URL: {deployed_resource.url}") - self._add_resource(resource_key, deployed_resource) - return deployed_resource + try: + deployed_resource = await self._deploy_with_error_context( + config + ) + log.info(f"URL: {deployed_resource.url}") + self._add_resource(resource_key, deployed_resource) + return deployed_resource + except Exception: + # Universal rule: If resource was created (has ID), track it for cleanup + if hasattr(config, "id") and config.id: + log.warning( + f"Deployment failed but resource '{config.name}' was created with ID {config.id}, " + f"caching for cleanup" + ) + self._add_resource(resource_key, config) + raise # Config unchanged, reuse existing log.debug(f"{existing} exists, reusing (config unchanged)") @@ -305,10 +327,20 @@ async def get_or_deploy_resource( f"Resource NOT found in cache, deploying new: {resource_key}\n" f" Searched in keys: {list(self._resources.keys())}" ) - deployed_resource = await self._deploy_with_error_context(config) - log.info(f"URL: {deployed_resource.url}") - self._add_resource(resource_key, deployed_resource) - return deployed_resource + try: + deployed_resource = await self._deploy_with_error_context(config) + log.info(f"URL: {deployed_resource.url}") + self._add_resource(resource_key, deployed_resource) + return deployed_resource + except Exception: + # Universal rule: If resource was created (has ID), track it for cleanup + if hasattr(config, "id") and config.id: + log.warning( + f"Deployment failed but resource '{config.name}' was created with ID {config.id}, " + f"caching for cleanup" + ) + self._add_resource(resource_key, config) + raise 
@asynccontextmanager async def resource_lock(self, uid: str): diff --git a/src/tetra_rp/runtime/lb_handler.py b/src/tetra_rp/runtime/lb_handler.py index 4193dd02..bde79d8a 100644 --- a/src/tetra_rp/runtime/lb_handler.py +++ b/src/tetra_rp/runtime/lb_handler.py @@ -39,7 +39,9 @@ def create_lb_handler( - route_registry: Dict[tuple[str, str], Callable], include_execute: bool = False + route_registry: Dict[tuple[str, str], Callable], + include_execute: bool = False, + lifespan: Callable = None, ) -> FastAPI: """Create FastAPI app with routes from registry. @@ -49,11 +51,12 @@ def create_lb_handler( include_execute: Whether to register /execute endpoint for @remote execution. Only used for LiveLoadBalancer (local development). Deployed endpoints should not expose /execute for security. + lifespan: Optional lifespan context manager for startup/shutdown hooks. Returns: Configured FastAPI application with routes registered. """ - app = FastAPI(title="Flash Load-Balanced Handler") + app = FastAPI(title="Flash Load-Balanced Handler", lifespan=lifespan) # Register /execute endpoint for @remote stub execution (if enabled) if include_execute: diff --git a/src/tetra_rp/runtime/manifest_client.py b/src/tetra_rp/runtime/manifest_client.py index bfe69ca8..dc1a5519 100644 --- a/src/tetra_rp/runtime/manifest_client.py +++ b/src/tetra_rp/runtime/manifest_client.py @@ -1,4 +1,4 @@ -"""HTTP client for mothership manifest directory API.""" +"""HTTP client for mothership /manifest endpoint API.""" import asyncio import logging @@ -17,14 +17,13 @@ class ManifestClient: - """HTTP client for querying mothership manifest directory service. + """HTTP client for querying mothership /manifest endpoint. - Fetches the endpoint registry that maps resource_config names to their - deployment URLs. This is the "manifest directory service" - an endpoint - registry showing where resources are deployed. 
+ Fetches the endpoint registry from the mothership's /manifest endpoint, + which maps resource_config names to their deployment URLs. - The directory maps resource_config names to their endpoint URLs. - Example: {"gpu_config": "https://api.runpod.io/v2/abc123"} + The manifest maps resource_config names to their endpoint URLs. + Example: {"gpu_config": "https://gpu-worker.api.runpod.ai"} """ def __init__( @@ -37,33 +36,37 @@ def __init__( Args: mothership_url: Base URL of mothership endpoint. Defaults to - FLASH_MOTHERSHIP_URL environment variable. + constructed from FLASH_MOTHERSHIP_ID environment variable. timeout: Request timeout in seconds (default: 10). max_retries: Maximum retry attempts (default: 3). Raises: ValueError: If mothership_url not provided and env var not set. """ - self.mothership_url = mothership_url or os.getenv("FLASH_MOTHERSHIP_URL") - if not self.mothership_url: - raise ValueError( - "mothership_url required: pass mothership_url or set " - "FLASH_MOTHERSHIP_URL environment variable" - ) + if mothership_url: + self.mothership_url = mothership_url + else: + mothership_id = os.getenv("FLASH_MOTHERSHIP_ID") + if not mothership_id: + raise ValueError( + "mothership_url required: pass mothership_url or set " + "FLASH_MOTHERSHIP_ID environment variable" + ) + self.mothership_url = f"https://{mothership_id}.api.runpod.ai" self.timeout = timeout self.max_retries = max_retries self._client: Optional[httpx.AsyncClient] = None async def get_directory(self) -> Dict[str, str]: - """Fetch endpoint directory from mothership. + """Fetch manifest from mothership /manifest endpoint. Returns: Dictionary mapping resource_config_name → endpoint_url. - Example: {"gpu_config": "https://api.runpod.io/v2/abc123"} + Example: {"gpu_config": "https://gpu-worker.api.runpod.ai"} Raises: - ManifestServiceUnavailableError: If manifest directory service unavailable after retries. + ManifestServiceUnavailableError: If /manifest endpoint unavailable after retries. 
""" if httpx is None: raise ImportError( @@ -76,7 +79,7 @@ async def get_directory(self) -> Dict[str, str]: try: client = await self._get_client() response = await client.get( - f"{self.mothership_url}/directory", + f"{self.mothership_url}/manifest", timeout=self.timeout, ) diff --git a/src/tetra_rp/runtime/mothership_provisioner.py b/src/tetra_rp/runtime/mothership_provisioner.py new file mode 100644 index 00000000..8035bf25 --- /dev/null +++ b/src/tetra_rp/runtime/mothership_provisioner.py @@ -0,0 +1,462 @@ +"""Mothership auto-provisioning logic with manifest reconciliation.""" + +import hashlib +import json +import logging +import os +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, List, Optional + +from tetra_rp.core.resources.base import DeployableResource +from tetra_rp.core.resources.resource_manager import ResourceManager + +from .state_manager_client import StateManagerClient + +logger = logging.getLogger(__name__) + + +@dataclass +class ManifestDiff: + """Result of manifest reconciliation.""" + + new: List[str] # Resources to deploy + changed: List[str] # Resources to update + removed: List[str] # Resources to delete + unchanged: List[str] # Resources to skip + + +def get_mothership_url() -> str: + """Construct mothership URL from RUNPOD_ENDPOINT_ID env var. + + Returns: + Mothership URL in format: https://{endpoint_id}.api.runpod.ai + + Raises: + RuntimeError: If RUNPOD_ENDPOINT_ID not set + """ + endpoint_id = os.getenv("RUNPOD_ENDPOINT_ID") + if not endpoint_id: + raise RuntimeError("RUNPOD_ENDPOINT_ID environment variable not set") + return f"https://{endpoint_id}.api.runpod.ai" + + +def is_mothership() -> bool: + """Check if current endpoint is mothership. + + Returns: + True if FLASH_IS_MOTHERSHIP env var is 'true' + """ + return os.getenv("FLASH_IS_MOTHERSHIP", "").lower() == "true" + + +def load_manifest(manifest_path: Optional[Path] = None) -> Dict[str, Any]: + """Load flash_manifest.json. 
+ + Args: + manifest_path: Explicit path to manifest. Tries env var and + auto-detection if not provided. + + Returns: + Manifest dictionary + + Raises: + FileNotFoundError: If manifest not found + """ + paths_to_try = [] + + # Explicit path + if manifest_path: + paths_to_try.append(manifest_path) + + # Environment variable + env_path = os.getenv("FLASH_MANIFEST_PATH") + if env_path: + paths_to_try.append(Path(env_path)) + + # Auto-detection: same directory as this file, or cwd + paths_to_try.extend( + [ + Path(__file__).parent.parent.parent / "flash_manifest.json", + Path.cwd() / "flash_manifest.json", + ] + ) + + # Try each path + for path in paths_to_try: + if path and path.exists(): + try: + with open(path) as f: + manifest_dict = json.load(f) + logger.debug(f"Manifest loaded from {path}") + return manifest_dict + except Exception as e: + logger.warning(f"Failed to load manifest from {path}: {e}") + continue + + raise FileNotFoundError( + f"flash_manifest.json not found. Searched paths: {paths_to_try}" + ) + + +def compute_resource_hash(resource_data: Dict[str, Any]) -> str: + """Compute hash of resource configuration for drift detection. + + Args: + resource_data: Resource configuration from manifest + + Returns: + SHA-256 hash of resource config + """ + # Convert to JSON and hash to detect changes + config_json = json.dumps(resource_data, sort_keys=True) + return hashlib.sha256(config_json.encode()).hexdigest() + + +def filter_resources_by_manifest( + all_resources: Dict[str, DeployableResource], + manifest: Dict[str, Any], +) -> Dict[str, DeployableResource]: + """Filter cached resources to only those defined in manifest. + + Prevents stale cache entries from being deployed by checking: + 1. Resource name exists in manifest + 2. Resource type matches manifest entry + + Stale entries can occur when codebase is refactored but the resource + cache still contains endpoints from an older version. 
+ + Args: + all_resources: All resources from ResourceManager cache + manifest: Current deployment manifest + + Returns: + Filtered dict containing only manifest-matching resources + """ + manifest_resources = manifest.get("resources", {}) + filtered = {} + removed_count = 0 + + for key, resource in all_resources.items(): + resource_name = resource.name if hasattr(resource, "name") else None + + if not resource_name: + logger.warning(f"Skipping cached resource without name: {key}") + removed_count += 1 + continue + + # Check if resource exists in manifest + if resource_name not in manifest_resources: + logger.info( + f"Removing stale cached resource '{resource_name}' " + f"(not in current manifest)" + ) + removed_count += 1 + continue + + # Check if type matches + manifest_entry = manifest_resources[resource_name] + expected_type = manifest_entry.get("resource_type") + actual_type = resource.__class__.__name__ + + if expected_type and expected_type != actual_type: + logger.warning( + f"Removing stale cached resource '{resource_name}' " + f"(type mismatch: cached={actual_type}, manifest={expected_type})" + ) + removed_count += 1 + continue + + filtered[key] = resource + + if removed_count > 0: + logger.info( + f"Cache validation: Removed {removed_count} stale " + f"resource(s) not matching manifest" + ) + + return filtered + + +def reconcile_manifests( + local_manifest: Dict[str, Any], + persisted_manifest: Optional[Dict[str, Any]], +) -> ManifestDiff: + """Compare local and persisted manifests to detect changes. 
+ + Args: + local_manifest: Current manifest from flash_manifest.json + persisted_manifest: Last known manifest from State Manager (None if first boot) + + Returns: + ManifestDiff with categorized resources + """ + local_resources = local_manifest.get("resources", {}) + persisted_resources = ( + persisted_manifest.get("resources", {}) if persisted_manifest else {} + ) + + new = [] + changed = [] + unchanged = [] + + for name, local_data in local_resources.items(): + if name not in persisted_resources: + new.append(name) + else: + # Compare config hashes to detect changes + local_hash = compute_resource_hash(local_data) + persisted_hash = persisted_resources[name].get("config_hash") + + if local_hash != persisted_hash: + changed.append(name) + else: + unchanged.append(name) + + # Detect removed resources (in persisted, not in local) + removed = [name for name in persisted_resources if name not in local_resources] + + return ManifestDiff(new=new, changed=changed, removed=removed, unchanged=unchanged) + + +def create_resource_from_manifest( + resource_name: str, + resource_data: Dict[str, Any], + mothership_url: str, +) -> DeployableResource: + """Create DeployableResource config from manifest entry. 
+ + Args: + resource_name: Name of the resource + resource_data: Resource configuration from manifest + mothership_url: Mothership URL to set in child env vars + + Returns: + Configured DeployableResource ready for deployment + + Raises: + ValueError: If resource type not supported + """ + from tetra_rp.core.resources.live_serverless import ( + CpuLiveLoadBalancer, + LiveLoadBalancer, + ) + from tetra_rp.core.resources.load_balancer_sls_resource import ( + LoadBalancerSlsResource, + ) + from tetra_rp.core.resources.serverless import ServerlessResource + + resource_type = resource_data.get("resource_type", "ServerlessResource") + + # Support both Serverless and LoadBalancer resource types + if resource_type not in [ + "ServerlessResource", + "LiveServerless", + "LoadBalancerSlsResource", + "LiveLoadBalancer", + "CpuLiveLoadBalancer", + ]: + raise ValueError( + f"Unsupported resource type for auto-provisioning: {resource_type}" + ) + + # Create resource with mothership environment variables + # Note: Manifest doesn't contain full deployment config (image, workers, etc.) + # For now, create a minimal config with required fields + # TODO: Enhance manifest to include deployment config (image, workers, GPU type, etc.) 
+ + # Create appropriate resource type based on manifest entry + import os + + env = { + "FLASH_RESOURCE_NAME": resource_name, + "FLASH_MOTHERSHIP_ID": os.getenv("RUNPOD_ENDPOINT_ID"), + } + + # Add "tmp-" prefix for test-mothership deployments + # Check environment variable set by test-mothership command + + is_test_mothership = os.getenv("FLASH_IS_TEST_MOTHERSHIP", "").lower() == "true" + + if is_test_mothership and not resource_name.startswith("tmp-"): + prefixed_name = f"tmp-{resource_name}" + logger.info(f"Test mode: Using temporary name '{prefixed_name}'") + else: + prefixed_name = resource_name + + if resource_type == "CpuLiveLoadBalancer": + resource = CpuLiveLoadBalancer(name=prefixed_name, env=env) + elif resource_type == "LiveLoadBalancer": + resource = LiveLoadBalancer(name=prefixed_name, env=env) + elif resource_type == "LoadBalancerSlsResource": + resource = LoadBalancerSlsResource(name=prefixed_name, env=env) + else: + # ServerlessResource and LiveServerless + resource = ServerlessResource(name=prefixed_name, env=env) + + return resource + + +async def provision_children( + manifest_path: Path, + mothership_url: str, + state_client: StateManagerClient, +) -> None: + """Provision all child resources with reconciliation. + + Orchestrates deployment/update/delete of resources based on manifest differences. 
+ + Args: + manifest_path: Path to flash_manifest.json + mothership_url: Mothership endpoint URL to set on children + state_client: State Manager API client + """ + try: + # Load local manifest + local_manifest = load_manifest(manifest_path) + + # Get persisted manifest from State Manager + mothership_id = os.getenv("RUNPOD_ENDPOINT_ID") + if not mothership_id: + logger.error("RUNPOD_ENDPOINT_ID not set, cannot load persisted manifest") + return + + persisted_manifest = await state_client.get_persisted_manifest(mothership_id) + + # Reconcile manifests + logger.info( + f"Starting reconciliation: {len(local_manifest.get('resources', {}))} manifest resources" + ) + + diff = reconcile_manifests(local_manifest, persisted_manifest) + + logger.info( + f"Reconciliation plan: {len(diff.new)} to deploy, " + f"{len(diff.changed)} to update, " + f"{len(diff.removed)} to remove, " + f"{len(diff.unchanged)} unchanged" + ) + + manager = ResourceManager() + + # Filter cached resources to prevent stale entries from being deployed + # This ensures resources from old codebase versions don't get redeployed + all_cached = manager.list_all_resources() + if all_cached: + valid_cached = filter_resources_by_manifest(all_cached, local_manifest) + logger.info( + f"Cache validation: {len(all_cached)} cached, " + f"{len(valid_cached)} valid, " + f"{len(local_manifest.get('resources', {}))} in manifest" + ) + + # Deploy NEW resources + for resource_name in diff.new: + try: + resource_data = local_manifest["resources"][resource_name] + config = create_resource_from_manifest( + resource_name, resource_data, mothership_url + ) + deployed = await manager.get_or_deploy_resource(config) + + # Update State Manager + await state_client.update_resource_state( + mothership_id, + resource_name, + { + "config_hash": compute_resource_hash(resource_data), + "endpoint_url": deployed.endpoint_url + if hasattr(deployed, "endpoint_url") + else deployed.url, + "status": "deployed", + }, + ) + 
logger.info(f"Deployed new resource: {resource_name}") + + except Exception as e: + logger.error(f"Failed to deploy {resource_name}: {e}") + try: + await state_client.update_resource_state( + mothership_id, + resource_name, + {"status": "failed", "error": str(e)}, + ) + except Exception as sm_error: + logger.error( + f"Failed to update State Manager for {resource_name}: {sm_error}" + ) + + # Update CHANGED resources + for resource_name in diff.changed: + try: + resource_data = local_manifest["resources"][resource_name] + config = create_resource_from_manifest( + resource_name, resource_data, mothership_url + ) + updated = await manager.get_or_deploy_resource(config) + + await state_client.update_resource_state( + mothership_id, + resource_name, + { + "config_hash": compute_resource_hash(resource_data), + "endpoint_url": updated.endpoint_url + if hasattr(updated, "endpoint_url") + else updated.url, + "status": "updated", + }, + ) + logger.info(f"Updated resource: {resource_name}") + + except Exception as e: + logger.error(f"Failed to update {resource_name}: {e}") + try: + await state_client.update_resource_state( + mothership_id, + resource_name, + {"status": "failed", "error": str(e)}, + ) + except Exception as sm_error: + logger.error( + f"Failed to update State Manager for {resource_name}: {sm_error}" + ) + + # Delete REMOVED resources + for resource_name in diff.removed: + try: + # Find resource in ResourceManager + matches = manager.find_resources_by_name(resource_name) + if matches: + resource_id, _ = matches[0] + result = await manager.undeploy_resource(resource_id, resource_name) + + if result["success"]: + try: + await state_client.remove_resource_state( + mothership_id, resource_name + ) + except Exception as sm_error: + logger.error( + f"Failed to remove {resource_name} from State Manager: {sm_error}" + ) + logger.info(f"Deleted removed resource: {resource_name}") + else: + logger.error( + f"Failed to delete {resource_name}: {result['message']}" + ) + 
else: + logger.warning( + f"Removed resource {resource_name} not found in ResourceManager" + ) + + except Exception as e: + logger.error(f"Failed to delete {resource_name}: {e}") + + logger.info("=" * 60) + logger.info("Provisioning complete - All child endpoints deployed") + logger.info(f"Total endpoints: {len(local_manifest.get('resources', {}))}") + logger.info("Test phase: Manifest updated with child endpoint URLs") + logger.info("=" * 60) + + except Exception as e: + logger.error(f"Provisioning failed: {e}", exc_info=True) diff --git a/src/tetra_rp/runtime/service_registry.py b/src/tetra_rp/runtime/service_registry.py index ddcbcd84..32ae0216 100644 --- a/src/tetra_rp/runtime/service_registry.py +++ b/src/tetra_rp/runtime/service_registry.py @@ -38,9 +38,14 @@ def __init__( manifest_path: Path to flash_manifest.json. Defaults to FLASH_MANIFEST_PATH env var or auto-detection. directory_client: Manifest service client for mothership API. If None, creates one - from FLASH_MOTHERSHIP_URL env var. + from FLASH_MOTHERSHIP_ID env var. cache_ttl: Directory cache lifetime in seconds (default: 300). + Environment Variables (for local vs remote detection): + FLASH_RESOURCE_NAME: Resource config name for this endpoint (child endpoints only). + Identifies which resource config this endpoint represents in the manifest. + RUNPOD_ENDPOINT_ID: Endpoint ID (used as fallback for mothership identification). + Raises: FileNotFoundError: If manifest_path doesn't exist. ValueError: If required env vars missing for directory_client. 
@@ -62,15 +67,23 @@ def __init__( # Initialize manifest client if directory_client is None: - mothership_url = os.getenv("FLASH_MOTHERSHIP_URL") - if mothership_url: - directory_client = ManifestClient(mothership_url=mothership_url) + mothership_id = os.getenv("FLASH_MOTHERSHIP_ID") + if mothership_id: + try: + directory_client = ManifestClient() + except ValueError as e: + logger.warning(f"Failed to initialize manifest client: {e}") + directory_client = None else: - logger.warning("FLASH_MOTHERSHIP_URL not set, directory unavailable") + logger.debug("FLASH_MOTHERSHIP_ID not set, directory unavailable") directory_client = None self._directory_client = directory_client - self._current_endpoint = os.getenv("RUNPOD_ENDPOINT_ID") + # Child endpoints use FLASH_RESOURCE_NAME to identify which resource config they represent + # Mothership doesn't have FLASH_RESOURCE_NAME, so falls back to RUNPOD_ENDPOINT_ID + self._current_endpoint = os.getenv("FLASH_RESOURCE_NAME") or os.getenv( + "RUNPOD_ENDPOINT_ID" + ) def _load_manifest(self, manifest_path: Optional[Path]) -> None: """Load flash_manifest.json. diff --git a/src/tetra_rp/runtime/state_manager_client.py b/src/tetra_rp/runtime/state_manager_client.py new file mode 100644 index 00000000..860c1e87 --- /dev/null +++ b/src/tetra_rp/runtime/state_manager_client.py @@ -0,0 +1,263 @@ +"""HTTP client for State Manager API to persist and reconcile manifests.""" + +import asyncio +import logging +import os +from typing import Any, Dict, Optional + +try: + import httpx +except ImportError: + httpx = None + +from .config import DEFAULT_MAX_RETRIES, DEFAULT_REQUEST_TIMEOUT +from .exceptions import ManifestServiceUnavailableError + +logger = logging.getLogger(__name__) + + +class StateManagerClient: + """HTTP client for State Manager API. + + The State Manager persists manifest state and provides reconciliation + capabilities for the mothership to track deployed resources across boots. 
+ """ + + def __init__( + self, + api_key: Optional[str] = None, + base_url: str = "https://api.runpod.io", + timeout: int = DEFAULT_REQUEST_TIMEOUT, + max_retries: int = DEFAULT_MAX_RETRIES, + ): + """Initialize State Manager client. + + Args: + api_key: RunPod API key. Defaults to RUNPOD_API_KEY env var. + base_url: Base URL for State Manager API. + timeout: Request timeout in seconds. + max_retries: Maximum retry attempts. + + Raises: + ValueError: If api_key not provided and env var not set. + """ + self.api_key = api_key or os.getenv("RUNPOD_API_KEY") + if not self.api_key: + raise ValueError( + "api_key required: pass api_key or set RUNPOD_API_KEY environment variable" + ) + + self.base_url = base_url + self.timeout = timeout + self.max_retries = max_retries + self._client: Optional[httpx.AsyncClient] = None + + async def get_persisted_manifest( + self, mothership_id: str + ) -> Optional[Dict[str, Any]]: + """Fetch persisted manifest from State Manager. + + Args: + mothership_id: ID of the mothership endpoint. + + Returns: + Manifest dict or None if not found (first boot). + + Raises: + ManifestServiceUnavailableError: If State Manager unavailable after retries. + """ + if httpx is None: + raise ImportError( + "httpx required for StateManagerClient. 
Install with: pip install httpx" + ) + + last_exception: Optional[Exception] = None + + for attempt in range(self.max_retries): + try: + client = await self._get_client() + response = await client.get( + f"{self.base_url}/api/v1/flash/manifests/{mothership_id}", + headers={"Authorization": f"Bearer {self.api_key}"}, + timeout=self.timeout, + ) + + if response.status_code == 404: + logger.debug( + f"No persisted manifest found for {mothership_id} (first boot)" + ) + return None + + if response.status_code >= 400: + raise ManifestServiceUnavailableError( + f"State Manager returned {response.status_code}: " + f"{response.text[:200]}" + ) + + data = response.json() + logger.debug(f"Persisted manifest loaded for {mothership_id}") + return data + + except ( + asyncio.TimeoutError, + ManifestServiceUnavailableError, + Exception, + ) as e: + last_exception = e + if attempt < self.max_retries - 1: + backoff = 2**attempt + logger.warning( + f"State Manager request failed (attempt {attempt + 1}): {e}, " + f"retrying in {backoff}s..." + ) + await asyncio.sleep(backoff) + continue + + raise ManifestServiceUnavailableError( + f"Failed to fetch persisted manifest after {self.max_retries} attempts: " + f"{last_exception}" + ) + + async def update_resource_state( + self, + mothership_id: str, + resource_name: str, + resource_data: Dict[str, Any], + ) -> None: + """Update single resource entry in State Manager. + + Args: + mothership_id: ID of the mothership endpoint. + resource_name: Name of the resource. + resource_data: Resource metadata (config_hash, endpoint_url, status, etc). + + Raises: + ManifestServiceUnavailableError: If State Manager unavailable. + """ + if httpx is None: + raise ImportError( + "httpx required for StateManagerClient. 
Install with: pip install httpx" + ) + + last_exception: Optional[Exception] = None + + for attempt in range(self.max_retries): + try: + client = await self._get_client() + response = await client.put( + f"{self.base_url}/api/v1/flash/manifests/{mothership_id}/resources/{resource_name}", + headers={"Authorization": f"Bearer {self.api_key}"}, + json=resource_data, + timeout=self.timeout, + ) + + if response.status_code >= 400: + raise ManifestServiceUnavailableError( + f"State Manager returned {response.status_code}: " + f"{response.text[:200]}" + ) + + logger.debug( + f"Updated resource state in State Manager: {mothership_id}/{resource_name}" + ) + return + + except ( + asyncio.TimeoutError, + ManifestServiceUnavailableError, + Exception, + ) as e: + last_exception = e + if attempt < self.max_retries - 1: + backoff = 2**attempt + logger.warning( + f"State Manager request failed (attempt {attempt + 1}): {e}, " + f"retrying in {backoff}s..." + ) + await asyncio.sleep(backoff) + continue + + raise ManifestServiceUnavailableError( + f"Failed to update resource state after {self.max_retries} attempts: " + f"{last_exception}" + ) + + async def remove_resource_state( + self, mothership_id: str, resource_name: str + ) -> None: + """Remove resource entry from State Manager. + + Args: + mothership_id: ID of the mothership endpoint. + resource_name: Name of the resource. + + Raises: + ManifestServiceUnavailableError: If State Manager unavailable. + """ + if httpx is None: + raise ImportError( + "httpx required for StateManagerClient. 
Install with: pip install httpx" + ) + + last_exception: Optional[Exception] = None + + for attempt in range(self.max_retries): + try: + client = await self._get_client() + response = await client.delete( + f"{self.base_url}/api/v1/flash/manifests/{mothership_id}/resources/{resource_name}", + headers={"Authorization": f"Bearer {self.api_key}"}, + timeout=self.timeout, + ) + + if response.status_code >= 400: + raise ManifestServiceUnavailableError( + f"State Manager returned {response.status_code}: " + f"{response.text[:200]}" + ) + + logger.debug( + f"Removed resource state from State Manager: {mothership_id}/{resource_name}" + ) + return + + except ( + asyncio.TimeoutError, + ManifestServiceUnavailableError, + Exception, + ) as e: + last_exception = e + if attempt < self.max_retries - 1: + backoff = 2**attempt + logger.warning( + f"State Manager request failed (attempt {attempt + 1}): {e}, " + f"retrying in {backoff}s..." + ) + await asyncio.sleep(backoff) + continue + + raise ManifestServiceUnavailableError( + f"Failed to remove resource state after {self.max_retries} attempts: " + f"{last_exception}" + ) + + async def _get_client(self) -> httpx.AsyncClient: + """Get or create HTTP client with proper configuration.""" + if self._client is None or self._client.is_closed: + timeout = httpx.Timeout(self.timeout) + self._client = httpx.AsyncClient(timeout=timeout) + + return self._client + + async def close(self) -> None: + """Close HTTP session.""" + if self._client and not self._client.is_closed: + await self._client.aclose() + + async def __aenter__(self): + """Async context manager entry.""" + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Async context manager exit.""" + await self.close() diff --git a/tests/integration/test_lb_remote_execution.py b/tests/integration/test_lb_remote_execution.py index 2aca85d0..00068d01 100644 --- a/tests/integration/test_lb_remote_execution.py +++ b/tests/integration/test_lb_remote_execution.py @@ 
-178,12 +178,16 @@ def test_live_load_balancer_handler_includes_execute_endpoint(self): "test-api": { "resource_type": "LiveLoadBalancer", "handler_file": "handler_test_api.py", + "is_load_balanced": True, + "is_live_resource": True, "functions": [ { "name": "process_data", "module": "api.endpoints", "is_async": True, "is_class": False, + "is_load_balanced": True, + "is_live_resource": True, "http_method": "POST", "http_path": "/api/process", } @@ -204,7 +208,7 @@ def test_live_load_balancer_handler_includes_execute_endpoint(self): # Verify the handler includes include_execute=True for LiveLoadBalancer assert "include_execute=True" in handler_code assert ( - "create_lb_handler(ROUTE_REGISTRY, include_execute=True)" + "create_lb_handler(ROUTE_REGISTRY, include_execute=True, lifespan=lifespan)" in handler_code ) @@ -228,12 +232,16 @@ def test_deployed_load_balancer_handler_excludes_execute_endpoint(self): "api-service": { "resource_type": "LoadBalancerSlsResource", "handler_file": "handler_api_service.py", + "is_load_balanced": True, + "is_live_resource": False, "functions": [ { "name": "process_data", "module": "api.endpoints", "is_async": True, "is_class": False, + "is_load_balanced": True, + "is_live_resource": False, "http_method": "POST", "http_path": "/api/process", } @@ -254,7 +262,7 @@ def test_deployed_load_balancer_handler_excludes_execute_endpoint(self): # Verify the handler includes include_execute=False for deployed endpoints assert "include_execute=False" in handler_code assert ( - "create_lb_handler(ROUTE_REGISTRY, include_execute=False)" + "create_lb_handler(ROUTE_REGISTRY, include_execute=False, lifespan=lifespan)" in handler_code ) @@ -305,6 +313,63 @@ def get_status(): assert "deployed-api" in scanner.resource_types assert scanner.resource_types["deployed-api"] == "LoadBalancerSlsResource" + def test_handler_generation_with_numeric_module_paths(self): + """Test that LB handlers use importlib for numeric module paths.""" + from 
tetra_rp.cli.commands.build_utils.lb_handler_generator import ( + LBHandlerGenerator, + ) + from datetime import datetime, timezone + from pathlib import Path + import tempfile + + # Create a manifest with numeric module paths + manifest = { + "version": "1.0", + "generated_at": datetime.now(timezone.utc) + .isoformat() + .replace("+00:00", "Z"), + "project_name": "test-project", + "resources": { + "test-api": { + "resource_type": "LoadBalancerSlsResource", + "handler_file": "handler_test_api.py", + "is_load_balanced": True, + "is_live_resource": False, + "functions": [ + { + "name": "gpu_health", + "module": "03_advanced_workers.05_load_balancer.workers.gpu.endpoint", + "is_async": True, + "is_class": False, + "is_load_balanced": True, + "is_live_resource": False, + "http_method": "GET", + "http_path": "/health", + } + ], + } + }, + } + + with tempfile.TemporaryDirectory() as tmpdir: + build_dir = Path(tmpdir) + generator = LBHandlerGenerator(manifest, build_dir) + handlers = generator.generate_handlers() + + assert len(handlers) == 1 + handler_path = handlers[0] + handler_code = handler_path.read_text() + + # Verify importlib pattern is used + assert "import importlib" in handler_code + assert ( + "gpu_health = importlib.import_module('03_advanced_workers.05_load_balancer.workers.gpu.endpoint').gpu_health" + in handler_code + ) + + # Verify no invalid from syntax + assert "from 03_advanced_workers" not in handler_code + class TestManifestEndpointIntegration: """Integration tests for GET /manifest endpoint.""" diff --git a/tests/integration/test_mothership_provisioning.py b/tests/integration/test_mothership_provisioning.py new file mode 100644 index 00000000..058c3074 --- /dev/null +++ b/tests/integration/test_mothership_provisioning.py @@ -0,0 +1,478 @@ +"""Integration tests for mothership auto-provisioning with manifest reconciliation.""" + +import tempfile +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from 
tetra_rp.runtime.mothership_provisioner import ( + compute_resource_hash, + provision_children, +) +from tetra_rp.runtime.state_manager_client import StateManagerClient + + +class TestMothershipProvisioningFlow: + """Integration tests for mothership provisioning workflow.""" + + @pytest.mark.asyncio + async def test_provision_children_first_boot(self): + """Test provisioning on first boot (no persisted manifest). + + Scenario: + - Mothership starts for the first time + - No persisted manifest in State Manager + - All resources in local manifest should be deployed as NEW + """ + # Setup: Create local manifest + local_manifest = { + "version": "1.0", + "resources": { + "gpu_worker": { + "resource_type": "ServerlessResource", + "config": "gpu_v1", + }, + "cpu_worker": { + "resource_type": "ServerlessResource", + "config": "cpu_v1", + }, + }, + } + + # Mock StateManagerClient + mock_state_client = AsyncMock(spec=StateManagerClient) + mock_state_client.get_persisted_manifest.return_value = None # No persisted + mock_state_client.update_resource_state = AsyncMock() + + # Mock ResourceManager + mock_gpu_resource = MagicMock() + mock_gpu_resource.endpoint_url = "https://gpu-worker.api.runpod.ai" + mock_cpu_resource = MagicMock() + mock_cpu_resource.endpoint_url = "https://cpu-worker.api.runpod.ai" + + with ( + patch("tetra_rp.runtime.mothership_provisioner.load_manifest") as mock_load, + patch( + "tetra_rp.runtime.mothership_provisioner.ResourceManager" + ) as mock_rm_class, + patch.dict( + "os.environ", + {"RUNPOD_ENDPOINT_ID": "mothership-123"}, + ), + ): + mock_load.return_value = local_manifest + + mock_manager = MagicMock() + mock_manager.get_or_deploy_resource = AsyncMock( + side_effect=[mock_gpu_resource, mock_cpu_resource] + ) + mock_rm_class.return_value = mock_manager + + # Execute + with tempfile.TemporaryDirectory() as tmpdir: + manifest_path = Path(tmpdir) / "flash_manifest.json" + mothership_url = "https://mothership-123.api.runpod.ai" + + await 
provision_children( + manifest_path, mothership_url, mock_state_client + ) + + # Verify: Both resources deployed + assert mock_manager.get_or_deploy_resource.call_count == 2 + assert mock_state_client.update_resource_state.call_count == 2 + + # Verify: State Manager updated with deployment info + calls = mock_state_client.update_resource_state.call_args_list + assert calls[0][0][1] == "gpu_worker" # resource_name + assert calls[1][0][1] == "cpu_worker" + + @pytest.mark.asyncio + async def test_provision_children_with_changes(self): + """Test provisioning with changed resources. + + Scenario: + - Mothership boots with updated manifest + - Some resources have changed config (different hash) + - Changed resources should be updated, unchanged skipped + """ + gpu_old_data = {"resource_type": "ServerlessResource", "config": "gpu_v1"} + gpu_new_data = {"resource_type": "ServerlessResource", "config": "gpu_v2"} + cpu_data = {"resource_type": "ServerlessResource", "config": "cpu_v1"} + + gpu_old_hash = compute_resource_hash(gpu_old_data) + cpu_hash = compute_resource_hash(cpu_data) + + local_manifest = { + "version": "1.0", + "resources": { + "gpu_worker": gpu_new_data, # Changed + "cpu_worker": cpu_data, # Unchanged + }, + } + + persisted_manifest = { + "version": "1.0", + "resources": { + "gpu_worker": { + **gpu_old_data, + "config_hash": gpu_old_hash, + }, + "cpu_worker": { + **cpu_data, + "config_hash": cpu_hash, + }, + }, + } + + # Mock StateManagerClient + mock_state_client = AsyncMock(spec=StateManagerClient) + mock_state_client.get_persisted_manifest.return_value = persisted_manifest + mock_state_client.update_resource_state = AsyncMock() + + # Mock ResourceManager - only called for changed resource + mock_gpu_resource = MagicMock() + mock_gpu_resource.endpoint_url = "https://gpu-worker.api.runpod.ai" + + with ( + patch("tetra_rp.runtime.mothership_provisioner.load_manifest") as mock_load, + patch( + "tetra_rp.runtime.mothership_provisioner.ResourceManager" + ) as 
mock_rm_class, + patch.dict( + "os.environ", + {"RUNPOD_ENDPOINT_ID": "mothership-123"}, + ), + ): + mock_load.return_value = local_manifest + + mock_manager = MagicMock() + mock_manager.get_or_deploy_resource = AsyncMock( + return_value=mock_gpu_resource + ) + mock_rm_class.return_value = mock_manager + + # Execute + with tempfile.TemporaryDirectory() as tmpdir: + manifest_path = Path(tmpdir) / "flash_manifest.json" + mothership_url = "https://mothership-123.api.runpod.ai" + + await provision_children( + manifest_path, mothership_url, mock_state_client + ) + + # Verify: Only changed resource deployed + assert mock_manager.get_or_deploy_resource.call_count == 1 + # Verify: State Manager updated only for changed resource + assert mock_state_client.update_resource_state.call_count == 1 + assert ( + mock_state_client.update_resource_state.call_args_list[0][0][1] + == "gpu_worker" + ) + + @pytest.mark.asyncio + async def test_provision_children_with_removed_resources(self): + """Test provisioning with removed resources. 
+ + Scenario: + - Manifest previously had 3 resources + - Current manifest has only 2 resources + - Removed resource should be undeployed and removed from State Manager + """ + local_manifest = { + "version": "1.0", + "resources": { + "gpu_worker": { + "resource_type": "ServerlessResource", + "config": "gpu_v1", + }, + }, + } + + persisted_manifest = { + "version": "1.0", + "resources": { + "gpu_worker": { + "resource_type": "ServerlessResource", + "config_hash": "abc123", + }, + "old_worker": { + "resource_type": "ServerlessResource", + "config_hash": "def456", + }, + }, + } + + # Mock StateManagerClient + mock_state_client = AsyncMock(spec=StateManagerClient) + mock_state_client.get_persisted_manifest.return_value = persisted_manifest + mock_state_client.update_resource_state = AsyncMock() + mock_state_client.remove_resource_state = AsyncMock() + + # Mock ResourceManager + mock_gpu_resource = MagicMock() + mock_gpu_resource.endpoint_url = "https://gpu-worker.api.runpod.ai" + + with ( + patch("tetra_rp.runtime.mothership_provisioner.load_manifest") as mock_load, + patch( + "tetra_rp.runtime.mothership_provisioner.ResourceManager" + ) as mock_rm_class, + patch.dict( + "os.environ", + {"RUNPOD_ENDPOINT_ID": "mothership-123"}, + ), + ): + mock_load.return_value = local_manifest + + mock_manager = MagicMock() + mock_manager.get_or_deploy_resource = AsyncMock( + return_value=mock_gpu_resource + ) + # find_resources_by_name returns list of tuples: (resource_id, resource) + mock_manager.find_resources_by_name = MagicMock( + return_value=[("resource-id-123", "old_worker")] + ) + mock_manager.undeploy_resource = AsyncMock(return_value={"success": True}) + mock_rm_class.return_value = mock_manager + + # Execute + with tempfile.TemporaryDirectory() as tmpdir: + manifest_path = Path(tmpdir) / "flash_manifest.json" + mothership_url = "https://mothership-123.api.runpod.ai" + + await provision_children( + manifest_path, mothership_url, mock_state_client + ) + + # Verify: Removed 
resource undeployed + assert mock_manager.undeploy_resource.call_count == 1 + # Verify: State Manager updated to remove old resource + assert mock_state_client.remove_resource_state.call_count == 1 + assert ( + mock_state_client.remove_resource_state.call_args_list[0][0][1] + == "old_worker" + ) + + @pytest.mark.asyncio + async def test_provision_children_deploys_load_balancer_resources(self): + """Test that LoadBalancer resources are provisioned during provisioning. + + Scenario: + - Manifest includes LoadBalancerSlsResource + - Mothership should deploy it as a child resource to RunPod + """ + local_manifest = { + "version": "1.0", + "resources": { + "mothership": { + "resource_type": "LoadBalancerSlsResource", + "config": "lb_v1", + }, + "gpu_worker": { + "resource_type": "ServerlessResource", + "config": "gpu_v1", + }, + }, + } + + # Mock StateManagerClient + mock_state_client = AsyncMock(spec=StateManagerClient) + mock_state_client.get_persisted_manifest.return_value = None + mock_state_client.update_resource_state = AsyncMock() + + # Mock ResourceManager + mock_lb_resource = MagicMock() + mock_lb_resource.endpoint_url = "https://mothership-lb.api.runpod.ai" + mock_gpu_resource = MagicMock() + mock_gpu_resource.endpoint_url = "https://gpu-worker.api.runpod.ai" + + with ( + patch("tetra_rp.runtime.mothership_provisioner.load_manifest") as mock_load, + patch( + "tetra_rp.runtime.mothership_provisioner.create_resource_from_manifest" + ) as mock_create, + patch( + "tetra_rp.runtime.mothership_provisioner.ResourceManager" + ) as mock_rm_class, + patch.dict( + "os.environ", + {"RUNPOD_ENDPOINT_ID": "mothership-123"}, + ), + ): + mock_load.return_value = local_manifest + mock_create.side_effect = [MagicMock(), MagicMock()] + + mock_manager = MagicMock() + mock_manager.get_or_deploy_resource = AsyncMock( + side_effect=[mock_lb_resource, mock_gpu_resource] + ) + mock_rm_class.return_value = mock_manager + + # Execute + with tempfile.TemporaryDirectory() as tmpdir: + 
manifest_path = Path(tmpdir) / "flash_manifest.json" + mothership_url = "https://mothership-123.api.runpod.ai" + + await provision_children( + manifest_path, mothership_url, mock_state_client + ) + + # Verify: Both mothership LoadBalancer and gpu_worker deployed + assert mock_manager.get_or_deploy_resource.call_count == 2 + # Verify: Both resources in State Manager + assert mock_state_client.update_resource_state.call_count == 2 + resource_names = [ + mock_state_client.update_resource_state.call_args_list[i][0][1] + for i in range(2) + ] + assert "mothership" in resource_names + assert "gpu_worker" in resource_names + + @pytest.mark.asyncio + async def test_provision_children_handles_deployment_errors(self): + """Test that deployment errors don't block other resources. + + Scenario: + - gpu_worker deployment fails + - cpu_worker deployment should still proceed + - State Manager should be updated with error for gpu_worker + """ + local_manifest = { + "version": "1.0", + "resources": { + "gpu_worker": { + "resource_type": "ServerlessResource", + "config": "gpu_v1", + }, + "cpu_worker": { + "resource_type": "ServerlessResource", + "config": "cpu_v1", + }, + }, + } + + # Mock StateManagerClient + mock_state_client = AsyncMock(spec=StateManagerClient) + mock_state_client.get_persisted_manifest.return_value = None + mock_state_client.update_resource_state = AsyncMock() + + # Mock ResourceManager - gpu_worker fails, cpu_worker succeeds + mock_cpu_resource = MagicMock() + mock_cpu_resource.endpoint_url = "https://cpu-worker.api.runpod.ai" + + with ( + patch("tetra_rp.runtime.mothership_provisioner.load_manifest") as mock_load, + patch( + "tetra_rp.runtime.mothership_provisioner.ResourceManager" + ) as mock_rm_class, + patch.dict( + "os.environ", + {"RUNPOD_ENDPOINT_ID": "mothership-123"}, + ), + ): + mock_load.return_value = local_manifest + + mock_manager = MagicMock() + # First call (gpu_worker) raises error, second call (cpu_worker) succeeds + 
mock_manager.get_or_deploy_resource = AsyncMock( + side_effect=[ + RuntimeError("GPU allocation failed"), + mock_cpu_resource, + ] + ) + mock_rm_class.return_value = mock_manager + + # Execute + with tempfile.TemporaryDirectory() as tmpdir: + manifest_path = Path(tmpdir) / "flash_manifest.json" + mothership_url = "https://mothership-123.api.runpod.ai" + + # Should not raise despite gpu_worker failure + await provision_children( + manifest_path, mothership_url, mock_state_client + ) + + # Verify: Both resources attempted + assert mock_manager.get_or_deploy_resource.call_count == 2 + + # Verify: State Manager updated for both (error for gpu, success for cpu) + assert mock_state_client.update_resource_state.call_count == 2 + + # Verify: Error recorded for gpu_worker + gpu_call = mock_state_client.update_resource_state.call_args_list[0] + assert gpu_call[0][1] == "gpu_worker" + assert "error" in gpu_call[0][2] + assert gpu_call[0][2]["status"] == "failed" + + # Verify: Success recorded for cpu_worker + cpu_call = mock_state_client.update_resource_state.call_args_list[1] + assert cpu_call[0][1] == "cpu_worker" + assert cpu_call[0][2]["status"] == "deployed" + + @pytest.mark.asyncio + async def test_idempotent_provisioning_on_second_boot(self): + """Test that second boot is idempotent (skips unchanged resources). 
+ + Scenario: + - First boot: Deploy gpu_worker, cpu_worker + - Second boot: Both resources unchanged (same hash) + - Second boot should skip both (no deployments) + """ + gpu_data = {"resource_type": "ServerlessResource", "config": "gpu_v1"} + cpu_data = {"resource_type": "ServerlessResource", "config": "cpu_v1"} + + gpu_hash = compute_resource_hash(gpu_data) + cpu_hash = compute_resource_hash(cpu_data) + + local_manifest = { + "version": "1.0", + "resources": { + "gpu_worker": gpu_data, + "cpu_worker": cpu_data, + }, + } + + persisted_manifest = { + "version": "1.0", + "resources": { + "gpu_worker": {**gpu_data, "config_hash": gpu_hash}, + "cpu_worker": {**cpu_data, "config_hash": cpu_hash}, + }, + } + + # Mock StateManagerClient + mock_state_client = AsyncMock(spec=StateManagerClient) + mock_state_client.get_persisted_manifest.return_value = persisted_manifest + mock_state_client.update_resource_state = AsyncMock() + + # Mock ResourceManager - should not be called + with ( + patch("tetra_rp.runtime.mothership_provisioner.load_manifest") as mock_load, + patch( + "tetra_rp.runtime.mothership_provisioner.ResourceManager" + ) as mock_rm_class, + patch.dict( + "os.environ", + {"RUNPOD_ENDPOINT_ID": "mothership-123"}, + ), + ): + mock_load.return_value = local_manifest + + mock_manager = MagicMock() + mock_manager.get_or_deploy_resource = AsyncMock() + mock_rm_class.return_value = mock_manager + + # Execute + with tempfile.TemporaryDirectory() as tmpdir: + manifest_path = Path(tmpdir) / "flash_manifest.json" + mothership_url = "https://mothership-123.api.runpod.ai" + + await provision_children( + manifest_path, mothership_url, mock_state_client + ) + + # Verify: No deployments (all unchanged) + assert mock_manager.get_or_deploy_resource.call_count == 0 + # Verify: State Manager not updated + assert mock_state_client.update_resource_state.call_count == 0 diff --git a/tests/unit/cli/commands/build_utils/test_manifest.py 
b/tests/unit/cli/commands/build_utils/test_manifest.py index 76b1de74..9ec0d06e 100644 --- a/tests/unit/cli/commands/build_utils/test_manifest.py +++ b/tests/unit/cli/commands/build_utils/test_manifest.py @@ -214,3 +214,32 @@ def test_manifest_generated_at_timestamp(): assert "generated_at" in manifest assert manifest["generated_at"].endswith("Z") + + +def test_manifest_includes_config_variable(): + """Test that manifest includes config_variable field.""" + functions = [ + RemoteFunctionMetadata( + function_name="health", + module_path="endpoint", + resource_config_name="my-endpoint", + resource_type="LiveLoadBalancer", + is_async=True, + is_class=False, + file_path=Path("endpoint.py"), + http_method="GET", + http_path="/health", + is_load_balanced=True, + is_live_resource=True, + config_variable="gpu_config", + ) + ] + + builder = ManifestBuilder("test-project", functions) + manifest = builder.build() + + assert manifest["resources"]["my-endpoint"]["config_variable"] == "gpu_config" + assert ( + manifest["resources"]["my-endpoint"]["functions"][0]["config_variable"] + == "gpu_config" + ) diff --git a/tests/unit/cli/commands/build_utils/test_scanner.py b/tests/unit/cli/commands/build_utils/test_scanner.py index 32e300e8..9114ec8e 100644 --- a/tests/unit/cli/commands/build_utils/test_scanner.py +++ b/tests/unit/cli/commands/build_utils/test_scanner.py @@ -437,3 +437,29 @@ async def my_function(data): assert len(functions) == 1 # Should preserve special characters in resource name assert functions[0].resource_config_name == "01_gpu-worker.v1" + + +def test_scanner_extracts_config_variable_names(): + """Test that scanner captures config variable names.""" + with tempfile.TemporaryDirectory() as tmpdir: + project_dir = Path(tmpdir) + test_file = project_dir / "endpoint.py" + + test_file.write_text( + """ +from tetra_rp import LiveLoadBalancer, remote + +gpu_config = LiveLoadBalancer(name="my-endpoint") + +@remote(gpu_config, method="GET", path="/health") +async def 
health(): + return {"status": "ok"} +""" + ) + + scanner = RemoteDecoratorScanner(project_dir) + functions = scanner.discover_remote_functions() + + assert len(functions) == 1 + assert functions[0].config_variable == "gpu_config" + assert functions[0].resource_config_name == "my-endpoint" diff --git a/tests/unit/runtime/test_manifest_client.py b/tests/unit/runtime/test_manifest_client.py index 27bb12cc..a3aba472 100644 --- a/tests/unit/runtime/test_manifest_client.py +++ b/tests/unit/runtime/test_manifest_client.py @@ -36,9 +36,9 @@ def test_init_with_url(self): def test_init_from_env(self): """Test initialization from environment variable.""" - with patch.dict(os.environ, {"FLASH_MOTHERSHIP_URL": "https://from-env.com"}): + with patch.dict(os.environ, {"FLASH_MOTHERSHIP_ID": "mothership123"}): client = ManifestClient() - assert client.mothership_url == "https://from-env.com" + assert client.mothership_url == "https://mothership123.api.runpod.ai" def test_init_missing_url(self): """Test initialization fails without URL.""" @@ -48,7 +48,7 @@ def test_init_missing_url(self): def test_init_explicit_over_env(self): """Test explicit URL takes precedence over env var.""" - with patch.dict(os.environ, {"FLASH_MOTHERSHIP_URL": "https://env.com"}): + with patch.dict(os.environ, {"FLASH_MOTHERSHIP_ID": "env-mothership"}): client = ManifestClient(mothership_url="https://explicit.com") assert client.mothership_url == "https://explicit.com" diff --git a/tests/unit/runtime/test_mothership_provisioner.py b/tests/unit/runtime/test_mothership_provisioner.py new file mode 100644 index 00000000..a4777eec --- /dev/null +++ b/tests/unit/runtime/test_mothership_provisioner.py @@ -0,0 +1,451 @@ +"""Unit tests for mothership provisioner module.""" + +import json +import os +import tempfile +from pathlib import Path +from unittest.mock import patch + +import pytest + +from tetra_rp.runtime.mothership_provisioner import ( + ManifestDiff, + compute_resource_hash, + 
create_resource_from_manifest, + get_mothership_url, + is_mothership, + load_manifest, + reconcile_manifests, +) + + +class TestGetMothershipUrl: + """Tests for get_mothership_url function.""" + + def test_get_mothership_url_from_env_var(self): + """Test constructing mothership URL from RUNPOD_ENDPOINT_ID.""" + with patch.dict(os.environ, {"RUNPOD_ENDPOINT_ID": "test-endpoint-123"}): + url = get_mothership_url() + assert url == "https://test-endpoint-123.api.runpod.ai" + + def test_get_mothership_url_missing_env_var(self): + """Test that RuntimeError is raised when RUNPOD_ENDPOINT_ID is not set.""" + with patch.dict(os.environ, {}, clear=True): + with pytest.raises(RuntimeError, match="RUNPOD_ENDPOINT_ID"): + get_mothership_url() + + def test_get_mothership_url_with_empty_env_var(self): + """Test that RuntimeError is raised when RUNPOD_ENDPOINT_ID is empty.""" + with patch.dict(os.environ, {"RUNPOD_ENDPOINT_ID": ""}): + with pytest.raises(RuntimeError, match="RUNPOD_ENDPOINT_ID"): + get_mothership_url() + + +class TestIsMothership: + """Tests for is_mothership function.""" + + def test_is_mothership_true(self): + """Test that is_mothership returns True when env var is 'true'.""" + with patch.dict(os.environ, {"FLASH_IS_MOTHERSHIP": "true"}): + assert is_mothership() is True + + def test_is_mothership_true_uppercase(self): + """Test that is_mothership returns True for 'TRUE'.""" + with patch.dict(os.environ, {"FLASH_IS_MOTHERSHIP": "TRUE"}): + assert is_mothership() is True + + def test_is_mothership_true_mixed_case(self): + """Test that is_mothership returns True for 'True'.""" + with patch.dict(os.environ, {"FLASH_IS_MOTHERSHIP": "True"}): + assert is_mothership() is True + + def test_is_mothership_false(self): + """Test that is_mothership returns False when env var is 'false'.""" + with patch.dict(os.environ, {"FLASH_IS_MOTHERSHIP": "false"}): + assert is_mothership() is False + + def test_is_mothership_missing_env_var(self): + """Test that is_mothership returns 
False when env var is not set.""" + with patch.dict(os.environ, {}, clear=True): + assert is_mothership() is False + + def test_is_mothership_empty_string(self): + """Test that is_mothership returns False for empty string.""" + with patch.dict(os.environ, {"FLASH_IS_MOTHERSHIP": ""}): + assert is_mothership() is False + + def test_is_mothership_invalid_value(self): + """Test that is_mothership returns False for invalid values.""" + with patch.dict(os.environ, {"FLASH_IS_MOTHERSHIP": "yes"}): + assert is_mothership() is False + + +class TestLoadManifest: + """Tests for load_manifest function.""" + + def test_load_manifest_from_explicit_path(self): + """Test loading manifest from explicit path.""" + manifest_data = {"version": "1.0", "resources": {}} + + with tempfile.TemporaryDirectory() as tmpdir: + manifest_path = Path(tmpdir) / "manifest.json" + manifest_path.write_text(json.dumps(manifest_data)) + + result = load_manifest(manifest_path) + assert result == manifest_data + + def test_load_manifest_from_env_var(self): + """Test loading manifest from environment variable.""" + manifest_data = {"version": "1.0", "resources": {}} + + with tempfile.TemporaryDirectory() as tmpdir: + manifest_path = Path(tmpdir) / "flash_manifest.json" + manifest_path.write_text(json.dumps(manifest_data)) + + with patch.dict(os.environ, {"FLASH_MANIFEST_PATH": str(manifest_path)}): + result = load_manifest() + assert result == manifest_data + + def test_load_manifest_not_found(self): + """Test that FileNotFoundError is raised when manifest is not found.""" + with tempfile.TemporaryDirectory() as tmpdir: + with patch("pathlib.Path.cwd", return_value=Path(tmpdir)): + with pytest.raises( + FileNotFoundError, match="flash_manifest.json not found" + ): + load_manifest() + + def test_load_manifest_invalid_json(self): + """Test that FileNotFoundError is raised for invalid JSON.""" + with tempfile.TemporaryDirectory() as tmpdir: + manifest_path = Path(tmpdir) / "flash_manifest.json" + 
manifest_path.write_text("invalid json {") + + with patch.dict(os.environ, {"FLASH_MANIFEST_PATH": str(manifest_path)}): + # Should continue searching when JSON is invalid + with pytest.raises(FileNotFoundError): + load_manifest() + + def test_load_manifest_searches_multiple_paths(self): + """Test that load_manifest searches multiple paths.""" + manifest_data = {"version": "1.0", "resources": {}} + + with tempfile.TemporaryDirectory() as tmpdir: + # Create manifest in cwd + manifest_path = Path(tmpdir) / "flash_manifest.json" + manifest_path.write_text(json.dumps(manifest_data)) + + with patch("pathlib.Path.cwd", return_value=Path(tmpdir)): + result = load_manifest() + assert result == manifest_data + + +class TestComputeResourceHash: + """Tests for compute_resource_hash function.""" + + def test_compute_resource_hash_basic(self): + """Test computing hash for basic resource data.""" + resource_data = {"name": "test", "type": "ServerlessResource"} + hash_value = compute_resource_hash(resource_data) + + # Verify it's a hex string + assert isinstance(hash_value, str) + assert len(hash_value) == 64 # SHA-256 hex is 64 chars + assert all(c in "0123456789abcdef" for c in hash_value) + + def test_compute_resource_hash_consistent(self): + """Test that same data produces same hash.""" + resource_data = {"name": "test", "type": "ServerlessResource"} + hash1 = compute_resource_hash(resource_data) + hash2 = compute_resource_hash(resource_data) + + assert hash1 == hash2 + + def test_compute_resource_hash_different_data(self): + """Test that different data produces different hashes.""" + data1 = {"name": "test1", "type": "ServerlessResource"} + data2 = {"name": "test2", "type": "ServerlessResource"} + + hash1 = compute_resource_hash(data1) + hash2 = compute_resource_hash(data2) + + assert hash1 != hash2 + + def test_compute_resource_hash_order_independent(self): + """Test that key order doesn't affect hash (JSON sorts keys).""" + data1 = {"name": "test", "type": 
"ServerlessResource"} + data2 = {"type": "ServerlessResource", "name": "test"} + + hash1 = compute_resource_hash(data1) + hash2 = compute_resource_hash(data2) + + # Should be same because json.dumps with sort_keys=True + assert hash1 == hash2 + + def test_compute_resource_hash_nested_data(self): + """Test computing hash for nested resource data.""" + resource_data = { + "name": "test", + "type": "ServerlessResource", + "config": { + "imageName": "test:latest", + "workers": {"min": 1, "max": 5}, + }, + } + hash_value = compute_resource_hash(resource_data) + + assert isinstance(hash_value, str) + assert len(hash_value) == 64 # SHA-256 hex is 64 chars + + +class TestReconcileManifests: + """Tests for reconcile_manifests function.""" + + def test_reconcile_manifests_empty_both(self): + """Test reconciliation with empty manifests.""" + local = {"resources": {}} + persisted = {"resources": {}} + + result = reconcile_manifests(local, persisted) + + assert isinstance(result, ManifestDiff) + assert result.new == [] + assert result.changed == [] + assert result.removed == [] + assert result.unchanged == [] + + def test_reconcile_manifests_new_resources(self): + """Test detection of new resources.""" + local = { + "resources": { + "worker1": {"resource_type": "ServerlessResource", "data": "v1"} + } + } + persisted = {"resources": {}} + + result = reconcile_manifests(local, persisted) + + assert result.new == ["worker1"] + assert result.changed == [] + assert result.removed == [] + assert result.unchanged == [] + + def test_reconcile_manifests_removed_resources(self): + """Test detection of removed resources.""" + local = {"resources": {}} + persisted = { + "resources": { + "worker1": { + "resource_type": "ServerlessResource", + "config_hash": "abc123", + } + } + } + + result = reconcile_manifests(local, persisted) + + assert result.new == [] + assert result.changed == [] + assert result.removed == ["worker1"] + assert result.unchanged == [] + + def 
test_reconcile_manifests_changed_resources(self): + """Test detection of changed resources.""" + local_data = {"resource_type": "ServerlessResource", "config": "v2"} + + persisted_data = {"resource_type": "ServerlessResource", "config": "v1"} + persisted_hash = compute_resource_hash(persisted_data) + + local = {"resources": {"worker1": local_data}} + persisted = { + "resources": {"worker1": {**persisted_data, "config_hash": persisted_hash}} + } + + result = reconcile_manifests(local, persisted) + + assert result.new == [] + assert result.changed == ["worker1"] + assert result.removed == [] + assert result.unchanged == [] + + def test_reconcile_manifests_unchanged_resources(self): + """Test detection of unchanged resources.""" + resource_data = {"resource_type": "ServerlessResource", "config": "v1"} + resource_hash = compute_resource_hash(resource_data) + + local = {"resources": {"worker1": resource_data}} + persisted = { + "resources": {"worker1": {**resource_data, "config_hash": resource_hash}} + } + + result = reconcile_manifests(local, persisted) + + assert result.new == [] + assert result.changed == [] + assert result.removed == [] + assert result.unchanged == ["worker1"] + + def test_reconcile_manifests_includes_load_balancer_resources(self): + """Test that LoadBalancer resources are included in provisioning.""" + local = { + "resources": { + "mothership": { + "resource_type": "LoadBalancerSlsResource", + "data": "v1", + }, + "worker1": {"resource_type": "ServerlessResource", "data": "v1"}, + } + } + persisted = {"resources": {}} + + result = reconcile_manifests(local, persisted) + + # LoadBalancer should be in new resources alongside Serverless + assert "mothership" in result.new + assert "worker1" in result.new + + def test_reconcile_manifests_includes_live_load_balancer(self): + """Test that LiveLoadBalancer resources are included in provisioning.""" + local = { + "resources": { + "live_mothership": { + "resource_type": "LiveLoadBalancer", + "data": "v1", + 
}, + "worker1": {"resource_type": "ServerlessResource", "data": "v1"}, + } + } + persisted = {"resources": {}} + + result = reconcile_manifests(local, persisted) + + # LiveLoadBalancer should be in new resources alongside Serverless + assert "live_mothership" in result.new + assert "worker1" in result.new + + def test_reconcile_manifests_persisted_none(self): + """Test reconciliation when persisted manifest is None (first boot).""" + local = { + "resources": { + "worker1": {"resource_type": "ServerlessResource", "data": "v1"} + } + } + + result = reconcile_manifests(local, None) + + assert result.new == ["worker1"] + assert result.changed == [] + assert result.removed == [] + + def test_reconcile_manifests_mixed_scenario(self): + """Test reconciliation with mixed new, changed, and removed resources.""" + resource_data_1 = {"resource_type": "ServerlessResource", "config": "v1"} + resource_hash_1 = compute_resource_hash(resource_data_1) + + resource_data_2_old = {"resource_type": "ServerlessResource", "config": "v1"} + resource_hash_2_old = compute_resource_hash(resource_data_2_old) + + resource_data_2_new = {"resource_type": "ServerlessResource", "config": "v2"} + + local = { + "resources": { + "new_worker": resource_data_1, + "changed_worker": resource_data_2_new, + "unchanged_worker": resource_data_1, + } + } + + persisted = { + "resources": { + "changed_worker": { + **resource_data_2_old, + "config_hash": resource_hash_2_old, + }, + "unchanged_worker": { + **resource_data_1, + "config_hash": resource_hash_1, + }, + "removed_worker": { + "resource_type": "ServerlessResource", + "config_hash": "old_hash", + }, + } + } + + result = reconcile_manifests(local, persisted) + + assert result.new == ["new_worker"] + assert result.changed == ["changed_worker"] + assert result.removed == ["removed_worker"] + assert "unchanged_worker" in result.unchanged + + +class TestCreateResourceFromManifest: + """Tests for create_resource_from_manifest function.""" + + def 
test_create_resource_from_manifest_serverless(self): + """Test creating ServerlessResource from manifest.""" + from tetra_rp.core.resources.serverless import ServerlessResource + + resource_name = "worker1" + resource_data = {"resource_type": "ServerlessResource"} + mothership_url = "https://test.api.runpod.ai" + + with patch.dict(os.environ, {"RUNPOD_ENDPOINT_ID": "mothership-123"}): + resource = create_resource_from_manifest( + resource_name, resource_data, mothership_url + ) + + assert isinstance(resource, ServerlessResource) + # ServerlessResource may append "-fb" suffix during initialization + assert resource_name in resource.name + assert resource.env["FLASH_MOTHERSHIP_ID"] == "mothership-123" + assert resource.env["FLASH_RESOURCE_NAME"] == resource_name + + def test_create_resource_from_manifest_live_serverless(self): + """Test that LiveServerless type is accepted but creates ServerlessResource. + + Note: Current implementation creates ServerlessResource regardless of type. + This is a known limitation - manifest needs to include full deployment config + to properly construct different resource types. 
+ """ + from tetra_rp.core.resources.serverless import ServerlessResource + + resource_name = "worker1" + resource_data = {"resource_type": "LiveServerless"} + mothership_url = "https://test.api.runpod.ai" + + with patch.dict(os.environ, {"RUNPOD_ENDPOINT_ID": "mothership-123"}): + # Should not raise - LiveServerless is in supported types + resource = create_resource_from_manifest( + resource_name, resource_data, mothership_url + ) + + # Returns ServerlessResource (current limitation) + assert isinstance(resource, ServerlessResource) + assert resource_name in resource.name + + def test_create_resource_from_manifest_unsupported_type(self): + """Test that ValueError is raised for unsupported resource types.""" + resource_name = "worker1" + resource_data = {"resource_type": "UnsupportedResourceType"} + mothership_url = "https://test.api.runpod.ai" + + with pytest.raises(ValueError, match="Unsupported resource type"): + create_resource_from_manifest(resource_name, resource_data, mothership_url) + + def test_create_resource_from_manifest_default_type(self): + """Test that default type is ServerlessResource when not specified.""" + from tetra_rp.core.resources.serverless import ServerlessResource + + resource_name = "worker1" + resource_data = {} # No resource_type specified + mothership_url = "https://test.api.runpod.ai" + + with patch.dict(os.environ, {"RUNPOD_ENDPOINT_ID": "mothership-123"}): + resource = create_resource_from_manifest( + resource_name, resource_data, mothership_url + ) + + assert isinstance(resource, ServerlessResource) + assert resource_name in resource.name diff --git a/tests/unit/runtime/test_service_registry.py b/tests/unit/runtime/test_service_registry.py index 8dc88aa1..21ef1028 100644 --- a/tests/unit/runtime/test_service_registry.py +++ b/tests/unit/runtime/test_service_registry.py @@ -83,28 +83,36 @@ def test_init_manifest_not_found(self): # Should not fail, returns empty manifest assert registry.get_manifest().function_registry == {} - def 
test_get_current_endpoint_id(self): - """Test retrieval of current endpoint ID from env.""" - with patch.dict(os.environ, {"RUNPOD_ENDPOINT_ID": "gpu-endpoint-123"}): + def test_get_current_endpoint_id_with_resource_name(self): + """Test retrieval using FLASH_RESOURCE_NAME (child endpoint).""" + with patch.dict(os.environ, {"FLASH_RESOURCE_NAME": "gpu_config"}): + registry = ServiceRegistry(manifest_path=Path("/nonexistent")) + assert registry.get_current_endpoint_id() == "gpu_config" + + def test_get_current_endpoint_id_fallback_to_runpod_id(self): + """Test fallback to RUNPOD_ENDPOINT_ID when FLASH_RESOURCE_NAME not set.""" + with patch.dict( + os.environ, {"RUNPOD_ENDPOINT_ID": "gpu-endpoint-123"}, clear=True + ): registry = ServiceRegistry(manifest_path=Path("/nonexistent")) assert registry.get_current_endpoint_id() == "gpu-endpoint-123" def test_get_current_endpoint_id_not_set(self): - """Test when endpoint ID not set.""" + """Test when neither env var is set.""" with patch.dict(os.environ, {}, clear=True): registry = ServiceRegistry(manifest_path=Path("/nonexistent")) assert registry.get_current_endpoint_id() is None def test_is_local_function_local(self, manifest_file): - """Test determining local function.""" - with patch.dict(os.environ, {"RUNPOD_ENDPOINT_ID": "gpu_config"}): + """Test determining local function using FLASH_RESOURCE_NAME.""" + with patch.dict(os.environ, {"FLASH_RESOURCE_NAME": "gpu_config"}): registry = ServiceRegistry(manifest_path=manifest_file) assert registry.is_local_function("gpu_task") is True assert registry.is_local_function("inference") is True def test_is_local_function_remote(self, manifest_file): """Test determining remote function (with directory loaded).""" - with patch.dict(os.environ, {"RUNPOD_ENDPOINT_ID": "gpu_config"}): + with patch.dict(os.environ, {"FLASH_RESOURCE_NAME": "gpu_config"}): mock_client = AsyncMock() mock_client.get_directory.return_value = { "cpu_config": "https://cpu.example.com" @@ -125,15 +133,15 @@ 
def test_is_local_function_not_in_manifest(self, manifest_file): assert registry.is_local_function("unknown_function") is True def test_get_endpoint_for_function_local(self, manifest_file): - """Test getting endpoint for local function.""" - with patch.dict(os.environ, {"RUNPOD_ENDPOINT_ID": "gpu_config"}): + """Test getting endpoint for local function using FLASH_RESOURCE_NAME.""" + with patch.dict(os.environ, {"FLASH_RESOURCE_NAME": "gpu_config"}): registry = ServiceRegistry(manifest_path=manifest_file) endpoint = registry.get_endpoint_for_function("gpu_task") assert endpoint is None # Local returns None def test_get_endpoint_for_function_remote_no_directory(self, manifest_file): """Test getting endpoint for remote function without directory.""" - with patch.dict(os.environ, {"RUNPOD_ENDPOINT_ID": "gpu_config"}): + with patch.dict(os.environ, {"FLASH_RESOURCE_NAME": "gpu_config"}): registry = ServiceRegistry(manifest_path=manifest_file) # CPU function is remote, but no directory loaded endpoint = registry.get_endpoint_for_function("preprocess") @@ -146,8 +154,8 @@ def test_get_endpoint_for_function_not_in_manifest(self, manifest_file): registry.get_endpoint_for_function("unknown_function") def test_get_resource_for_function_local(self, manifest_file): - """Test getting ServerlessResource for local function.""" - with patch.dict(os.environ, {"RUNPOD_ENDPOINT_ID": "gpu_config"}): + """Test getting ServerlessResource for local function using FLASH_RESOURCE_NAME.""" + with patch.dict(os.environ, {"FLASH_RESOURCE_NAME": "gpu_config"}): registry = ServiceRegistry(manifest_path=manifest_file) resource = registry.get_resource_for_function("gpu_task") # Local function returns None @@ -155,7 +163,7 @@ def test_get_resource_for_function_local(self, manifest_file): def test_get_resource_for_function_remote(self, manifest_file): """Test getting ServerlessResource for remote function.""" - with patch.dict(os.environ, {"RUNPOD_ENDPOINT_ID": "gpu_config"}): + with 
patch.dict(os.environ, {"FLASH_RESOURCE_NAME": "gpu_config"}): mock_client = AsyncMock() mock_client.get_directory.return_value = { "cpu_config": "https://api.runpod.io/v2/abc123" diff --git a/uv.lock b/uv.lock index 32ecc49b..773f9688 100644 --- a/uv.lock +++ b/uv.lock @@ -2916,6 +2916,7 @@ dependencies = [ { name = "questionary" }, { name = "rich" }, { name = "runpod" }, + { name = "tomli", marker = "python_full_version < '3.11'" }, { name = "typer" }, ] @@ -2943,6 +2944,7 @@ requires-dist = [ { name = "questionary", specifier = ">=2.0.0" }, { name = "rich", specifier = ">=14.0.0" }, { name = "runpod", git = "https://github.com/runpod/runpod-python?rev=main" }, + { name = "tomli", marker = "python_full_version < '3.11'", specifier = ">=2.0.0" }, { name = "typer", specifier = ">=0.12.0" }, ]