From 85a20ba74d8b0c730711171badda13f2c13ed4f8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dean=20Qui=C3=B1anola?= Date: Thu, 22 Jan 2026 00:36:05 -0800 Subject: [PATCH 01/11] docs: Flash Deploy Guide --- docs/Flash_Deploy_Guide.md | 1341 ++++++++++++++++++++++++++++++++++++ 1 file changed, 1341 insertions(+) create mode 100644 docs/Flash_Deploy_Guide.md diff --git a/docs/Flash_Deploy_Guide.md b/docs/Flash_Deploy_Guide.md new file mode 100644 index 00000000..b09dab14 --- /dev/null +++ b/docs/Flash_Deploy_Guide.md @@ -0,0 +1,1341 @@ +# Flash Deploy Guide + +## Overview + +Flash Deploy is a distributed runtime system that enables scalable execution of `@remote` functions across dynamically provisioned RunPod serverless endpoints. It bridges the gap between local development and production cloud deployment through a unified interface. + +### System Goals + +1. **Transparency**: Developers write local Python, deploy to cloud without code changes +2. **Scalability**: Functions execute on remote serverless endpoints with resource isolation +3. **Flexibility**: Support both queue-based and load-balanced execution models +4. **Reliability**: Automatic resource provisioning, state reconciliation, and drift detection + +### High-Level Architecture + +```mermaid +graph TB + Developer["Developer Machine"] + + subgraph Build["Build Phase"] + Scan["Scanner
Find @remote"] + Gen["Generator
Create Handlers"] + Manifest["ManifestBuilder
flash_manifest.json"] + end + + subgraph Cloud["RunPod Cloud"] + S3["S3 Storage
archive.tar.gz"] + + subgraph Mothership["Mothership Endpoint
(FLASH_IS_MOTHERSHIP=true)"] + MothershipProvisioner["MothershipsProvisioner
Reconcile & Deploy"] + MothershipRegistry["Manifest Cache
Function Registry"] + end + + subgraph ChildEndpoints["Child Endpoints
(Resource Configs)"] + Handler1["GPU Handler
@remote functions"] + Handler2["CPU Handler
@remote functions"] + LocalRegistry["Service Registry
Cross-endpoint routing"] + end + end + + Database["RunPod State Manager
GraphQL API"] + + Developer -->|flash build| Build + Build -->|archive| S3 + Developer -->|flash deploy send| S3 + S3 -->|download| Mothership + Mothership -->|reconcile| ChildEndpoints + MothershipProvisioner -->|update state| Database + ChildEndpoints -->|query manifest| Database + Developer -->|call @remote| ChildEndpoints + + style Mothership fill:#1976d2,stroke:#0d47a1,stroke-width:3px,color:#fff + style ChildEndpoints fill:#388e3c,stroke:#1b5e20,stroke-width:3px,color:#fff + style Build fill:#f57c00,stroke:#e65100,stroke-width:3px,color:#fff +``` + +### Key Concepts + +**Mothership**: The orchestration endpoint responsible for deployment, resource provisioning, and manifest distribution. Created via `flash deploy new `. + +**Child Endpoints**: Worker endpoints that execute `@remote` functions. One per resource config (e.g., `gpu_config`, `cpu_config`). + +**Manifest**: JSON document describing all deployed functions, their resource configs, routing rules, and metadata. Built at compile-time, distributed to all endpoints. + +**Resource Config**: A Python object that defines CloudResource specifications (GPU type, memory, image, etc.). Becomes a deployable endpoint. + +**Service Registry**: Runtime component that maps function names to endpoint URLs and determines local vs remote execution. + +--- + +## CLI Commands Reference + +### flash deploy new + +Create a new deployment environment (mothership). + +```bash +flash deploy new [--app-name ] +``` + +**Arguments:** +- `env_name`: Name for the deployment environment + +**Options:** +- `--app-name `: Flash app name (auto-detected if not provided) + +**What it does:** +1. Creates a FlashApp in RunPod (if first environment for the app) +2. Creates FlashEnvironment with the specified name +3. 
Provisions a mothership serverless endpoint + +**Example:** +```bash +flash deploy new production +# Output: Environment 'production' created successfully +# Environment ID: flash-prod-abc123 +# Next: flash deploy send production +``` + +**Implementation:** `src/tetra_rp/cli/commands/deploy.py:38-50` + +--- + +### flash deploy send + +Deploy built archive to an environment. + +```bash +flash deploy send [--app-name ] +``` + +**Arguments:** +- `env_name`: Name of the deployment environment + +**Options:** +- `--app-name `: Flash app name (auto-detected if not provided) + +**Prerequisites:** +- Archive must exist at `.flash/archive.tar.gz` (created by `flash build`) + +**What it does:** +1. Uploads archive.tar.gz to S3 +2. Notifies mothership endpoint to download and extract +3. Triggers auto-provisioning on mothership boot + +**Example:** +```bash +flash deploy send production +# Output: 🚀 Deploying to 'production'... +# Deployment Complete +``` + +**Implementation:** `src/tetra_rp/cli/commands/deploy.py:197-224` + +--- + +### flash deploy list + +List all deployment environments for an app. + +```bash +flash deploy list [--app-name ] +``` + +**Options:** +- `--app-name `: Flash app name (auto-detected if not provided) + +**Output:** Table showing: +- Environment name +- Environment ID +- Active build ID +- Creation timestamp + +**Implementation:** `src/tetra_rp/cli/commands/deploy.py:27-135` + +--- + +### flash deploy info + +Show detailed information about a deployment environment. + +```bash +flash deploy info [--app-name ] +``` + +**Arguments:** +- `env_name`: Name of the deployment environment + +**Options:** +- `--app-name `: Flash app name (auto-detected if not provided) + +**Output:** Displays: +- Environment status and ID +- Active build ID +- Associated endpoints +- Associated network volumes + +**Implementation:** `src/tetra_rp/cli/commands/deploy.py:69-111` + +--- + +### flash deploy delete + +Delete a deployment environment. 
+ +```bash +flash deploy delete [--app-name ] +``` + +**Arguments:** +- `env_name`: Name of the deployment environment + +**Options:** +- `--app-name `: Flash app name (auto-detected if not provided) + +**Safety:** +- Requires confirmation (twice for safety) +- Cannot be undone + +**Implementation:** `src/tetra_rp/cli/commands/deploy.py:237-270` + +--- + +## Complete Deployment Flow + +This section walks through the entire journey from source code to executing remote functions. + +### Phase 1: Build (Local Development) + +```mermaid +sequenceDiagram + Developer->>Build: flash build + Build->>Build: Scan files for @remote + Build->>Build: Find resource configs
(e.g., gpu_config, cpu_config) + Build->>Build: Generate handler_gpu_config.py
handler_cpu_config.py + Build->>Build: Scan functions per resource
Build function registry + Build->>ManifestBuilder: Create manifest entry
per resource config + ManifestBuilder->>ManifestBuilder: Validate routes
(no conflicts) + ManifestBuilder->>ManifestBuilder: Detect load-balanced
vs queue-based + ManifestBuilder->>Manifest: Write flash_manifest.json + Build->>TarGz: Package build directory + TarGz->>Archive: Create .flash/archive.tar.gz + Archive->>Developer: Build complete +``` + +**Scanner** (`src/tetra_rp/cli/commands/build_utils/scanner.py`): +- Decorators scanned: `@remote`, `@load_balanced`, `@cluster` +- Extracts: function name, module path, async status, HTTP routing info +- Groups functions by resource config + +**Handler Generation** (`src/tetra_rp/cli/commands/build_utils/handler_generator.py`): +- Per-resource handlers: `handler_{resource_name}.py` +- Contains `handle(job_input)` function for RunPod +- Invokes discovered `@remote` functions + +**Manifest Building** (`src/tetra_rp/cli/commands/build_utils/manifest.py`): +- Structure: + ```json + { + "version": "1.0", + "generated_at": "2024-01-21T10:00:00Z", + "project_name": "my_project", + "resources": { + "gpu_config": { + "resource_type": "LiveServerless", + "handler_file": "handler_gpu_config.py", + "functions": [{"name": "process", "module": "main", ...}], + "is_load_balanced": false + } + }, + "function_registry": {"process": "gpu_config"}, + "routes": {} + } + ``` + +**Archive Structure**: +``` +archive.tar.gz +├── flash_manifest.json # Manifest (source of truth) +├── src/ # Application source code +├── handler_gpu_config.py # Generated handlers +├── handler_cpu_config.py +└── vendor/ # Bundled dependencies +``` + +**Key Files:** +- `src/tetra_rp/cli/commands/build.py` - Entry point for `flash build` +- `src/tetra_rp/cli/commands/build_utils/manifest.py` - ManifestBuilder + +--- + +### Phase 2: Deploy to Environment + +```mermaid +sequenceDiagram + Developer->>CLI: flash deploy send + CLI->>S3: Upload .flash/archive.tar.gz + CLI->>RunPod: POST /run
with archive URL + RunPod->>Mothership: Boot mothership endpoint + Mothership->>Mothership: Download archive from URL + Mothership->>Mothership: Extract to working dir + Mothership->>Mothership: Load flash_manifest.json + Mothership->>MothershipsProvisioner: Trigger reconciliation +``` + +**Upload Process** (`src/tetra_rp/cli/commands/deploy.py:197-224`): +1. Archive uploaded to RunPod's built-in S3 storage +2. URL generated with temporary access +3. URL passed to mothership endpoint creation + +**Key Files:** +- `src/tetra_rp/cli/commands/deploy.py` - Deploy CLI commands + +--- + +### Phase 3: Mothership Boot & Auto-Provisioning + +The mothership runs on each boot to reconcile desired state (manifest) with current state (local resources). + +```mermaid +sequenceDiagram + RunPod->>Mothership: Boot endpoint + Mothership->>Mothership: Initialize runtime + Mothership->>MothershipsProvisioner: Load manifest + MothershipsProvisioner->>MothershipsProvisioner: Read flash_manifest.json + MothershipsProvisioner->>StateManager: Fetch persisted state + StateManager->>GraphQL: Query activeBuildId →<br/>
manifest + GraphQL->>StateManager: Return persisted manifest + MothershipsProvisioner->>MothershipsProvisioner: Compute diff:<br/>
new, changed, removed + MothershipsProvisioner->>ResourceManager: Deploy resources + ResourceManager->>RunPod: Create GPU endpoint<br/>
(gpu_config) + RunPod->>GPU: Boot with handler + GPU->>Mothership: Register endpoint + ResourceManager->>RunPod: Create CPU endpoint<br/>
(cpu_config) + RunPod->>CPU: Boot with handler + CPU->>Mothership: Register endpoint + MothershipsProvisioner->>StateManager: Update state<br/>
with endpoints + StateManager->>GraphQL: Mutation:
updateFlashBuildManifest +``` + +**Key Components:** + +**MothershipsProvisioner** (`src/tetra_rp/runtime/mothership_provisioner.py`): +- `is_mothership()`: Check if endpoint is mothership (FLASH_IS_MOTHERSHIP=true) +- `reconcile_manifest()`: Compute diff between desired and current state +- Orchestrates resource creation via ResourceManager +- Updates State Manager with deployed resources + +**ResourceManager** (`src/tetra_rp/core/resources/resource_manager.py`): +- Singleton pattern (global resource registry) +- Stores state in `.runpod/resources.pkl` with file locking +- Tracks config hashes for drift detection (hash comparison) +- Supports parallel deployment (max 3 concurrent, configurable) +- Auto-migrates legacy resources + +**StateManagerClient** (`src/tetra_rp/runtime/state_manager_client.py`): +- GraphQL client for persisting manifest state +- Read-modify-write pattern for updates (3 GQL roundtrips) +- Thread-safe with asyncio.Lock for concurrent updates +- Retries with exponential backoff (3 attempts) + +**Reconciliation Logic**: +1. **Fetch persisted manifest**: Query State Manager for previous deployment state +2. **Compare with current manifest**: Detect new, changed, and removed resources +3. **Deploy new resources**: Create endpoints for new resource configs +4. **Update changed resources**: Apply configuration changes if hash differs +5. **Remove deleted resources**: Decommission endpoints no longer in manifest +6. **Persist new state**: Update State Manager with current deployment + +**Key Files:** +- `src/tetra_rp/runtime/mothership_provisioner.py` - Reconciliation logic +- `src/tetra_rp/core/resources/resource_manager.py` - Resource provisioning +- `src/tetra_rp/runtime/state_manager_client.py` - State persistence + +--- + +### Phase 4: Child Endpoint Initialization + +Each child endpoint boots independently and prepares for function execution. 
+ +```mermaid +sequenceDiagram + RunPod->>Child: Boot with handler_gpu_config.py + Child->>Child: Initialize runtime + Child->>ManifestFetcher: Load manifest + ManifestFetcher->>ManifestFetcher: Check cache
(TTL: 300s) + alt Cache expired + ManifestFetcher->>RunPod: Query GraphQL API
getManifest() + RunPod->>ManifestFetcher: Return manifest + else Cache valid + ManifestFetcher->>ManifestFetcher: Return cached + end + ManifestFetcher->>Child: Manifest loaded + Child->>ServiceRegistry: Load manifest + ServiceRegistry->>ServiceRegistry: Build function_registry + ServiceRegistry->>ServiceRegistry: Build resource_mapping + Child->>Mothership: GET /manifest
Get endpoint registry + Mothership->>Child: Return {resource_name: url} + Child->>ServiceRegistry: Cache endpoint URLs + Child->>Ready: Ready to execute functions +``` + +**ManifestFetcher** (`src/tetra_rp/runtime/manifest_fetcher.py`): +- Caches manifest with TTL (default: 300s) +- Fetches from RunPod GraphQL API (source of truth) +- Falls back to local flash_manifest.json if API unavailable +- Updates local file with fetched data +- Thread-safe with asyncio.Lock + +**ServiceRegistry** (`src/tetra_rp/runtime/service_registry.py`): +- Loads manifest to build function registry +- Queries mothership for endpoint URLs: `GET /manifest` +- Returns mapping: `{resource_config_name: endpoint_url}` +- Determines local vs remote function calls: + - Local: Function's resource config matches FLASH_RESOURCE_NAME env var + - Remote: Query ServiceRegistry for endpoint URL + +**Key Environment Variables**: +- `FLASH_RESOURCE_NAME`: This endpoint's resource config name (e.g., "gpu_config") +- `FLASH_MOTHERSHIP_ID`: Mothership endpoint ID for manifest queries +- `FLASH_MANIFEST_PATH`: Optional override for manifest location +- `RUNPOD_ENDPOINT_ID`: This endpoint's RunPod endpoint ID + +**Key Files:** +- `src/tetra_rp/runtime/manifest_fetcher.py` - Manifest loading with caching +- `src/tetra_rp/runtime/service_registry.py` - Service discovery +- `src/tetra_rp/runtime/generic_handler.py` - Handler utilities + +--- + +### Phase 5: Runtime Function Execution + +When client calls `@remote function`: + +```mermaid +sequenceDiagram + Client->>Stub: Call remote_function(x, y) + Stub->>Stub: Serialize args/kwargs
cloudpickle + base64 + Stub->>ServiceRegistry: Lookup resource_config + ServiceRegistry->>ServiceRegistry: Check if local
or remote + alt Local endpoint + Stub->>Handler: Direct call + Handler->>Function: Execute + Handler->>Stub: Return result + else Remote endpoint + Stub->>Endpoint: POST /execute
{function, args, kwargs} + Endpoint->>Handler: Route to handler + Handler->>Function: Execute remote + Function->>Handler: Return result + Handler->>Endpoint: Serialize result + Endpoint->>Stub: Return serialized result + Stub->>Stub: Deserialize result + end + Stub->>Client: Return unwrapped result +``` + +**Serialization** (`src/tetra_rp/runtime/serialization.py`): +- **Args/Kwargs**: cloudpickle → base64 +- **Result**: cloudpickle → base64 +- Max payload size: 10MB + +**Handler Routing**: + +**Queue-Based** (`src/tetra_rp/runtime/generic_handler.py`): + +Uses a factory function `create_handler(function_registry)` that returns a RunPod-compatible handler: + +```python +def handler(job: Dict[str, Any]) -> Dict[str, Any]: + """RunPod serverless handler. + + Args: + job: RunPod job dict with 'input' key + + Returns: + Response dict with 'success', 'result'/'error' keys + """ + job_input = job.get("input", {}) + function_name = job_input.get("function_name") + execution_type = job_input.get("execution_type", "function") + + # Look up and execute function/class + func_or_class = function_registry[function_name] + args, kwargs = deserialize_arguments(job_input) + result = execute_function(func_or_class, args, kwargs, execution_type, job_input) + + return { + "success": True, + "result": serialize_result(result), + } +``` + +**Load-Balanced** (`src/tetra_rp/runtime/lb_handler.py`): +- FastAPI app with user-defined HTTP routes +- `/execute` endpoint for @remote execution (LiveLoadBalancer only) +- `/manifest` endpoint for mothership service discovery +- User routes: HTTP methods + paths from manifest + +**Key Files:** +- `src/tetra_rp/runtime/generic_handler.py` - Queue-based handler +- `src/tetra_rp/runtime/lb_handler.py` - Load-balanced handler factory +- `src/tetra_rp/runtime/serialization.py` - cloudpickle serialization +- `src/tetra_rp/runtime/service_registry.py` - Cross-endpoint routing + +--- + +## Manifest System + +The manifest is the contract between 
build-time and runtime. It defines all deployable functions, their resource configs, and routing rules. + +### Build-Time: Generation + +**Location**: Generated during `flash build` + +**Builder**: `ManifestBuilder` in `src/tetra_rp/cli/commands/build_utils/manifest.py` + +**Input**: +- List of discovered `@remote` functions (from scanner) +- Each function has: + - Name, module, async status + - Resource config name + - HTTP routing info (for load-balanced) + +**Output Structure**: +```json +{ + "version": "1.0", + "generated_at": "2024-01-21T10:00:00.000Z", + "project_name": "my_app", + "resources": { + "gpu_config": { + "resource_type": "LiveServerless", + "handler_file": "handler_gpu_config.py", + "functions": [ + { + "name": "train", + "module": "ml.training", + "is_async": false, + "is_class": false, + "is_load_balanced": false, + "is_live_resource": false, + "config_variable": "gpu_config" + } + ], + "is_load_balanced": false, + "is_live_resource": false, + "config_variable": "gpu_config" + } + }, + "function_registry": { + "train": "gpu_config" + }, + "routes": {} +} +``` + +**Validation**: +- No duplicate function names across resources +- No duplicate routes (method + path conflicts) +- Load-balanced endpoints have method and path +- No reserved paths (/execute, /ping, /manifest) + +**Code Reference**: `src/tetra_rp/cli/commands/build_utils/manifest.py:50-164` + +--- + +### Runtime: Distribution & Caching + +**Mothership Side** - `ManifestFetcher`: + +1. **Check cache**: Is manifest cached and TTL valid? + - Cache TTL: 300 seconds (configurable) + - Thread-safe with asyncio.Lock + +2. **Fetch from source**: If cache expired + - Primary: RunPod GraphQL API (via RunpodGraphQLClient) + - Fallback: Local flash_manifest.json file + +3. **Update local file**: Persist fetched manifest + - Allows recovery if API unavailable + +4. 
**Return to caller**: Cached manifest + +**Code Reference**: `src/tetra_rp/runtime/manifest_fetcher.py:47-118` + +**Child Endpoint Side** - `ServiceRegistry`: + +1. **Load manifest**: From local file + - Searches multiple locations (cwd, module dir, etc) + - Fallback to empty manifest if not found + +2. **Build function registry**: Map function_name → resource_config + - Used to determine local vs remote execution + +3. **Query mothership**: Get endpoint URLs + - Endpoint: `GET https://{mothership_id}.api.runpod.ai/manifest` + - Returns: `{"resource_name": "https://endpoint.api.runpod.ai"}` + - Retries with exponential backoff + +4. **Cache endpoints**: Store for routing decisions + +**Code Reference**: `src/tetra_rp/runtime/service_registry.py:29-80` + +--- + +### State Persistence: StateManagerClient + +The State Manager persists manifest state in RunPod's infrastructure, enabling: +- Mothership boot consistency +- Cross-boot resource tracking +- Manifest reconciliation + +**Architecture**: +``` +Client Update Request + ↓ +StateManagerClient (with asyncio.Lock) + ↓ +Read: Query GraphQL for activeBuildId + ↓ +Fetch: Get current manifest from build + ↓ +Modify: Merge changes into manifest + ↓ +Write: Mutation updateFlashBuildManifest +``` + +**GraphQL Operations**: +- Query: `getFlashEnvironment(flashEnvironmentId) → activeBuildId` +- Query: `getFlashBuild(buildId) → manifest` +- Mutation: `updateFlashBuildManifest(buildId, manifest)` + +**Thread Safety**: +- `asyncio.Lock` serializes read-modify-write +- Prevents race conditions during concurrent resource updates +- Max retries: 3 (default), with exponential backoff + +**Performance**: Each update = 3 GQL roundtrips +- Consider batching when provisioning many resources + +**Code Reference**: `src/tetra_rp/runtime/state_manager_client.py:53-248` + +--- + +## Resource Provisioning + +Resources are dynamically provisioned by the mothership during boot, based on the manifest. 
+ +### ResourceManager: Local State + +**Singleton Pattern**: Global resource registry, one per process + +**Persistent Storage**: `.runpod/resources.pkl` +- Format: `(Dict[str, DeployableResource], Dict[str, str])` (resources, config_hashes) +- Concurrent access protected by file locking (cross-platform) +- Pickled with cloudpickle for serialization + +**Operations**: + +**Load**: Read persisted resources on initialization +- Shared lock (non-exclusive) for concurrent reads +- Auto-migrates from legacy hash-based keys to name-based keys + +**Save**: Persist state to disk +- Exclusive lock for writing +- Atomic write to prevent corruption + +**Drift Detection**: Config hash comparison +- Each resource has `config_hash` (computed from resource config) +- Stored hash vs current hash: + - Match: Resource unchanged, skip update + - Mismatch: Resource modified, trigger update + +**Auto-Migration**: Legacy format conversion +- Old format: `{resource_id_hash: resource}` +- New format: `{ResourceType:name: resource}` +- Enables name-based lookup and drift detection + +**Code Reference**: `src/tetra_rp/core/resources/resource_manager.py:22-150` + +--- + +### Deployment Orchestration + +**MothershipsProvisioner** reconciles manifest with local state: + +```python +# 1. Load manifest from flash_manifest.json +manifest = load_manifest() + +# 2. Fetch persisted state from State Manager +persisted = await StateManagerClient.get_persisted_manifest(mothership_id) + +# 3. Compute diff +diff = compute_manifest_diff(manifest, persisted) +# diff.new: Resources to deploy +# diff.changed: Resources to update (hash differs) +# diff.removed: Resources to delete + +# 4. Deploy resources (parallel, max 3 concurrent) +for resource_config in diff.new: + deploy_resource(resource_config) + +# 5. Update changed resources +for resource_config in diff.changed: + update_resource(resource_config) + +# 6. 
Remove deleted resources +for resource_config in diff.removed: + delete_resource(resource_config) + +# 7. Persist new state +await StateManagerClient.update_resource_state(mothership_id, resources) +``` + +**Parallel Deployment**: +- Max concurrent: 3 (configurable) +- Reduces boot time for deployments with many resources + +**Hash-Based Drift Detection**: +- Stored hash (from previous boot) vs current hash (computed from config) +- If hashes differ: Resource has been modified, trigger update +- Prevents unnecessary updates when resource unchanged + +**Code Reference**: `src/tetra_rp/runtime/mothership_provisioner.py:1-150` + +--- + +## Remote Execution + +When `@remote function` is called, the client determines whether to execute locally or remotely. + +### Execution Modes + +**Queue-Based (Serverless Hub)**: +- No HTTP routing +- Single handler function: `handle(job_input)` +- Input serialized in payload +- Example: `LiveServerless` resource + +**Load-Balanced (FastAPI)**: +- HTTP routing with user-defined routes +- `/execute` endpoint for framework use +- `/manifest` endpoint for service discovery +- Example: `LiveLoadBalancer` resource + +### FunctionRequest/FunctionResponse Protocol + +**Request** (JSON): +```json +{ + "function": "train", + "args": ["base64_cloudpickle_args"], + "kwargs": {"base64": "cloudpickle_kwargs"}, + "timeout": 3600 +} +``` + +**Response** (JSON): +```json +{ + "result": "base64_cloudpickle_result", + "error": null, + "status": "completed" +} +``` + +### Serialization + +**Engine**: cloudpickle (handles lambdas, closures, custom classes) + +**Encoding**: base64 (JSON-safe) + +**Max Size**: 10MB (enforced at serialization) + +**Process**: +```python +# Serialize +import cloudpickle, base64 +serialized = base64.b64encode(cloudpickle.dumps(args)) + +# Deserialize +deserialized = cloudpickle.loads(base64.b64decode(serialized)) +``` + +**Code Reference**: `src/tetra_rp/runtime/serialization.py` + +--- + +### Handler Runtime + +**Generic 
Handler** (Queue-Based): + +Uses a factory function `create_handler(function_registry)` that creates a RunPod-compatible handler: + +```python +# src/tetra_rp/runtime/generic_handler.py - conceptual flow +def handler(job: Dict[str, Any]) -> Dict[str, Any]: + # Extract job input + job_input = job.get("input", {}) + + # Extract function name and execution type + function_name = job_input.get("function_name") + execution_type = job_input.get("execution_type", "function") + + # Deserialize arguments from base64-encoded cloudpickle + args, kwargs = deserialize_arguments(job_input) + + # Lookup and execute function/class from registry + func_or_class = function_registry[function_name] + result = execute_function(func_or_class, args, kwargs, execution_type, job_input) + + # Serialize result to base64-encoded cloudpickle + return { + "success": True, + "result": serialize_result(result), + } +``` + +**Load-Balanced Handler** (FastAPI): + +Uses `create_lb_handler(route_registry, include_execute=False)` factory: + +```python +# src/tetra_rp/runtime/lb_handler.py - conceptual structure +app = FastAPI() + +# User-defined routes registered from route_registry +for (method, path), handler_func in route_registry.items(): + app.add_api_route(path, handler_func, methods=[method]) + +# Framework /execute endpoint (LiveLoadBalancer only) +@app.post("/execute") +async def execute_remote_function(request: Request) -> Dict[str, Any]: + body = await request.json() + + # Extract function metadata + function_name = body.get("function_name") + function_code = body.get("function_code") + + # Deserialize arguments + args = deserialize_args(body.get("args", [])) + kwargs = deserialize_kwargs(body.get("kwargs", {})) + + # Execute function code in isolated namespace + namespace = {} + exec(function_code, namespace) + func = namespace[function_name] + result = func(*args, **kwargs) + + # Return serialized result + return { + "success": True, + "result": serialize_result(result), + } + +# Manifest 
service discovery endpoint +@app.get("/manifest") +async def get_manifest() -> Dict[str, Any]: + fetcher = ManifestFetcher() + return await fetcher.get_manifest() +``` + +**Code References**: +- `src/tetra_rp/runtime/generic_handler.py` - Queue-based handler +- `src/tetra_rp/runtime/lb_handler.py` - Load-balanced handler + +--- + +### Cross-Endpoint Routing + +**ServiceRegistry** determines function endpoint: + +```python +# src/tetra_rp/runtime/service_registry.py +registry = ServiceRegistry() + +# Lookup function's resource config +resource_config_name = registry.function_to_resource(function_name) + +# Check if local or remote +local_resource_name = os.getenv("FLASH_RESOURCE_NAME") +if resource_config_name == local_resource_name: + # Local: execute directly + return execute_locally(function_name, args, kwargs) +else: + # Remote: fetch endpoint URL + endpoint_url = registry.get_endpoint_url(resource_config_name) + return execute_remotely(endpoint_url, function_name, args, kwargs) +``` + +**Endpoint URL Caching**: +- Queries mothership: `GET /manifest` +- Caches with TTL (default: 300s) +- Retries with exponential backoff if query fails + +--- + +## Architecture Diagrams + +### System Component Diagram + +```mermaid +graph TB + subgraph Build["Build (Local)"] + Scanner["Scanner
RemoteDecoratorScanner"] + Generator["Generator
HandlerGenerator"] + LBGen["LB Generator
LBHandlerGenerator"] + ManifestB["ManifestBuilder"] + end + + subgraph Packaging["Packaging"] + Archive["Archive Builder
(tar.gz)"] + end + + subgraph Upload["Upload"] + S3["S3 Storage"] + end + + subgraph MothershipBoot["Mothership Boot"] + Fetcher["ManifestFetcher
(cache + GQL)"] + MProvisioner["MothershipsProvisioner
(reconciliation)"] + ResMgr["ResourceManager
(state)"] + StateMgr["StateManagerClient
(persistence)"] + end + + subgraph ChildBoot["Child Endpoint Boot"] + ChildFetcher["ManifestFetcher
(local file)"] + Registry["ServiceRegistry
(function mapping)"] + ManifestC["ManifestClient
(query mothership)"] + end + + subgraph Runtime["Runtime Execution"] + Handler["Handler
(generic or LB)"] + Serial["Serialization
(cloudpickle)"] + Exec["Function Execution"] + end + + Scanner --> Generator + Scanner --> LBGen + Generator --> ManifestB + LBGen --> ManifestB + ManifestB --> Archive + Archive --> S3 + S3 --> Fetcher + Fetcher --> MProvisioner + MProvisioner --> ResMgr + ResMgr --> StateMgr + StateMgr -->|update| S3 + ChildFetcher --> Registry + ManifestC -->|query| Fetcher + Registry --> ManifestC + Handler --> Serial + Serial --> Exec + + style Build fill:#f57c00,stroke:#e65100,stroke-width:3px,color:#fff + style MothershipBoot fill:#1976d2,stroke:#0d47a1,stroke-width:3px,color:#fff + style ChildBoot fill:#388e3c,stroke:#1b5e20,stroke-width:3px,color:#fff + style Runtime fill:#7b1fa2,stroke:#4a148c,stroke-width:3px,color:#fff +``` + +--- + +### Manifest Lifecycle + +```mermaid +graph LR + A["Build Time
ManifestBuilder"] -->|Generate| B["flash_manifest.json
(embedded in archive)"] + B -->|Upload| C["S3
(archive.tar.gz)"] + C -->|Download| D["Mothership
(extract archive)"] + D -->|Load & Query
RunPod GQL| E["ManifestFetcher
(cache)"] + D -->|Update State| F["StateManager
(GraphQL API)"] + G["Child Endpoint
(boot)"] -->|Load from
local file| H["LocalManifest
(from archive)"] + H -->|Build registry| I["ServiceRegistry
(function mapping)"] + I -->|Query| J["Mothership
GET /manifest"] + J -->|Return endpoints| I + I -->|Route calls| K["Handler
(execute)"] + + style A fill:#f57c00,stroke:#e65100,stroke-width:2px,color:#fff + style B fill:#ff6f00,stroke:#e65100,stroke-width:2px,color:#fff + style C fill:#ff6f00,stroke:#e65100,stroke-width:2px,color:#fff + style D fill:#1976d2,stroke:#0d47a1,stroke-width:2px,color:#fff + style E fill:#1976d2,stroke:#0d47a1,stroke-width:2px,color:#fff + style F fill:#0d47a1,stroke:#051c66,stroke-width:2px,color:#fff + style G fill:#388e3c,stroke:#1b5e20,stroke-width:2px,color:#fff + style H fill:#388e3c,stroke:#1b5e20,stroke-width:2px,color:#fff + style I fill:#388e3c,stroke:#1b5e20,stroke-width:2px,color:#fff + style K fill:#388e3c,stroke:#1b5e20,stroke-width:2px,color:#fff +``` + +--- + +### Resource Provisioning State Machine + +```mermaid +graph LR + A["Mothership Boots"] -->|Load manifest| B["Desired State"] + B -->|Fetch persisted| C["Current State"] + C -->|Compute diff| D{"Reconciliation"} + D -->|new| E["Create Resource"] + E -->|Deploy| F["Running"] + F -->|Update state| G["State Manager"] + D -->|changed| H["Update Resource"] + H -->|Hash differs| F + H -->|Hash same| I["Skip"] + D -->|removed| J["Delete Resource"] + J -->|Decommission| K["Deleted"] + K -->|Remove state| G + G -->|On next boot| C + + style A fill:#1976d2,stroke:#0d47a1,stroke-width:2px,color:#fff + style B fill:#1976d2,stroke:#0d47a1,stroke-width:2px,color:#fff + style C fill:#1976d2,stroke:#0d47a1,stroke-width:2px,color:#fff + style D fill:#ff6f00,stroke:#e65100,stroke-width:3px,color:#fff + style E fill:#388e3c,stroke:#1b5e20,stroke-width:2px,color:#fff + style F fill:#388e3c,stroke:#1b5e20,stroke-width:2px,color:#fff + style G fill:#0d47a1,stroke:#051c66,stroke-width:2px,color:#fff +``` + +--- + +## Environment Variables Reference + +### Mothership Configuration + +**FLASH_IS_MOTHERSHIP** (Required on mothership) +- Value: `"true"` +- Enables mothership auto-provisioning logic +- Triggers manifest reconciliation on boot + +**RUNPOD_ENDPOINT_ID** (Required on mothership) +- RunPod serverless 
endpoint ID +- Used to construct mothership URL: `https://{RUNPOD_ENDPOINT_ID}.api.runpod.ai` +- Set automatically by RunPod platform + +**RUNPOD_API_KEY** (Required for State Manager) +- RunPod API authentication token +- Used by StateManagerClient for GraphQL queries +- Enables manifest persistence + +### Child Endpoint Configuration + +**FLASH_RESOURCE_NAME** (Required on child endpoints) +- Resource config name (e.g., "gpu_config", "cpu_config") +- Identifies which resource config this endpoint represents +- Used by ServiceRegistry for local vs remote detection + +**FLASH_MOTHERSHIP_ID** (Required for remote calls) +- Mothership endpoint ID +- Used to construct mothership URL for manifest queries +- Format: `{endpoint_id}` (constructed as `https://{FLASH_MOTHERSHIP_ID}.api.runpod.ai`) + +**FLASH_MANIFEST_PATH** (Optional) +- Override default manifest file location +- If not set, searches: cwd, module dir, parent dirs +- Useful for testing or non-standard layouts + +### Runtime Configuration + +**RUNPOD_ENDPOINT_ID** (Set by RunPod) +- This endpoint's ID +- Used for logging and identification + +--- + +## State Management + +Flash Deploy uses a dual-layer state system for reliability and consistency. + +### Local State: `.runpod/resources.pkl` + +**Purpose**: Fast resource lookup within endpoint + +**Format**: +```python +( + resources: Dict[str, DeployableResource], # Name → Resource object + config_hashes: Dict[str, str] # Name → Config hash +) +``` + +**Lifecycle**: +1. **Load**: On initialization, read from disk (shared lock) +2. **Modify**: Add/update/remove resources in memory +3. **Save**: Write to disk (exclusive lock) +4. 
**Persist**: File survives endpoint restarts + +**Concurrency**: +- Cross-platform file locking (fcntl on Unix, msvcrt on Windows) +- Shared lock for reads (multiple concurrent readers) +- Exclusive lock for writes (single writer) +- Prevents data corruption during concurrent access + +**Code Reference**: `src/tetra_rp/core/resources/resource_manager.py:46-150` + +### Remote State: RunPod State Manager (GraphQL API) + +**Purpose**: Persist deployment state across mothership boots + +**Data Model**: +```graphql +type FlashEnvironment { + id: ID! + name: String! + activeBuildId: ID +} + +type FlashBuild { + id: ID! + manifest: JSON! # Manifest with resource state +} +``` + +**Resource Entry** (in manifest): +```json +{ + "resource_name": { + "endpoint_id": "gpu-worker-abc123", + "endpoint_url": "https://gpu-worker-abc123.api.runpod.ai", + "config_hash": "sha256_hash_of_config", + "status": "running", + "deployed_at": "2024-01-21T10:00:00Z" + } +} +``` + +**Update Pattern** (Thread-Safe): +```python +async with state_manager_lock: + # 1. Fetch current state + build_id = await client.get_flash_environment(env_id) + manifest = await client.get_flash_build(build_id) + + # 2. Modify state + manifest["resources"][resource_name] = new_resource_data + + # 3. Write back atomically + await client.update_build_manifest(build_id, manifest) +``` + +**Reconciliation**: +On mothership boot: +1. Load local manifest (desired state) +2. Fetch persisted manifest (current state) +3. Compare → detect new, changed, removed resources +4. Update resources +5. Persist new state + +**Code Reference**: `src/tetra_rp/runtime/state_manager_client.py` + +--- + +## Testing & Debugging + +### flash test-mothership + +Local testing of mothership provisioning without deploying to RunPod. + +```bash +flash test-mothership +``` + +**What it does**: +1. Loads flash_manifest.json from current directory +2. Creates temporary resource configs (prefixed with `tmp-`) +3. 
Simulates mothership provisioning locally +4. Displays resource creation output +5. Auto-cleanup on exit + +**Use Cases**: +- Validate manifest structure before deployment +- Test resource provisioning logic +- Debug handler generation + +**Code Reference**: `src/tetra_rp/cli/commands/test_mothership.py` + +### Local Docker Testing + +For testing complete deployment flow locally: + +```bash +# Build project +flash build + +# Start local mothership simulator +docker run -it \ + -e FLASH_IS_MOTHERSHIP=true \ + -e RUNPOD_API_KEY=$RUNPOD_API_KEY \ + -v $(pwd)/.flash:/workspace/.flash \ + tetra-rp:latest + +# Run provisioner +python -m tetra_rp.runtime.mothership_provisioner +``` + +### Debugging Tips + +**Enable Debug Logging**: +```python +import logging +logging.basicConfig(level=logging.DEBUG) +``` + +**Inspect Manifest**: +```bash +python -c "import json; print(json.dumps(json.load(open('flash_manifest.json')), indent=2))" +``` + +**Check Local Resources**: +```python +from tetra_rp.core.resources.resource_manager import ResourceManager +rm = ResourceManager() +print(rm._resources) +print(rm._resource_configs) +``` + +**Trace Remote Calls**: +```python +# Add logging to ServiceRegistry +import logging +logging.getLogger("tetra_rp.runtime.service_registry").setLevel(logging.DEBUG) +``` + +--- + +## Key Files Reference + +### CLI + +| File | Purpose | +|------|---------| +| `src/tetra_rp/cli/commands/deploy.py` | Deploy environment management commands | +| `src/tetra_rp/cli/commands/build.py` | Build packaging and archive creation | +| `src/tetra_rp/cli/commands/test_mothership.py` | Local mothership testing | + +### Build System + +| File | Purpose | +|------|---------| +| `src/tetra_rp/cli/commands/build_utils/scanner.py` | Scans for @remote decorators | +| `src/tetra_rp/cli/commands/build_utils/handler_generator.py` | Generates queue-based handlers | +| `src/tetra_rp/cli/commands/build_utils/lb_handler_generator.py` | Generates load-balanced handlers | +| 
`src/tetra_rp/cli/commands/build_utils/manifest.py` | Manifest builder and validation | + +### Resource Management + +| File | Purpose | +|------|---------| +| `src/tetra_rp/core/resources/resource_manager.py` | Resource provisioning and state tracking | +| `src/tetra_rp/core/resources/base.py` | Base resource types | +| `src/tetra_rp/core/resources/serverless.py` | Serverless resource implementations | + +### Runtime: Manifest & State + +| File | Purpose | +|------|---------| +| `src/tetra_rp/runtime/manifest_fetcher.py` | Manifest fetching with caching (GQL API) | +| `src/tetra_rp/runtime/manifest_client.py` | HTTP client for mothership manifest API | +| `src/tetra_rp/runtime/state_manager_client.py` | GraphQL client for state persistence | +| `src/tetra_rp/runtime/mothership_provisioner.py` | Auto-provisioning logic | + +### Runtime: Execution + +| File | Purpose | +|------|---------| +| `src/tetra_rp/runtime/generic_handler.py` | Queue-based handler factory | +| `src/tetra_rp/runtime/lb_handler.py` | Load-balanced (FastAPI) handler factory | +| `src/tetra_rp/runtime/service_registry.py` | Service discovery and routing | +| `src/tetra_rp/runtime/serialization.py` | cloudpickle serialization/deserialization | + +### Utilities + +| File | Purpose | +|------|---------| +| `src/tetra_rp/core/utils/file_lock.py` | Cross-platform file locking | +| `src/tetra_rp/core/utils/singleton.py` | Singleton pattern implementation | +| `src/tetra_rp/runtime/config.py` | Runtime configuration constants | + +--- + +## Common Issues & Solutions + +### Issue: Manifest not found on child endpoint + +**Cause**: flash_manifest.json not included in archive or not found at runtime + +**Solution**: +1. Verify archive contains flash_manifest.json: `tar -tzf archive.tar.gz | grep manifest` +2. Check `FLASH_MANIFEST_PATH` env var if using custom location +3. 
Ensure flash_manifest.json is in build root when creating archive + +### Issue: Resource provisioning fails with hash mismatch + +**Cause**: Resource config changed between boots, but hash wasn't updated + +**Solution**: +1. ResourceManager auto-refreshes hashes on init +2. Delete `.runpod/resources.pkl` to force reset +3. Check if resource config changes are intentional (then hash should differ) + +### Issue: Remote function calls fail with endpoint not found + +**Cause**: ServiceRegistry unable to query mothership or manifest outdated + +**Solution**: +1. Verify `FLASH_MOTHERSHIP_ID` environment variable is set +2. Check mothership endpoint is running: `curl https://{mothership_id}.api.runpod.ai/manifest` +3. Verify manifest includes the resource config: `grep resource_name flash_manifest.json` +4. Check network connectivity between child and mothership endpoints + +### Issue: Manifest cache staleness + +**Cause**: Manifest updated but cache still valid (TTL) + +**Solution**: +1. Default TTL: 300 seconds +2. Clear cache manually: `ManifestFetcher().invalidate_cache()` +3. 
Reduce TTL in production if frequent updates needed (trade-off: more API calls) + +--- + +## Performance Considerations + +### Manifest Fetching + +- **Cache TTL**: Default 300s +- **API calls**: Only on cache miss +- **Fallback**: Local file (no network) +- **Optimization**: Increase TTL for stable deployments + +### State Manager Updates + +- **GraphQL calls**: 3 per resource update (fetch env, fetch build, update build) +- **Locking**: Serializes concurrent updates +- **Recommendation**: Batch updates when provisioning many resources + +### Serialization + +- **Engine**: cloudpickle (slower than pickle for complex objects) +- **Encoding**: base64 (JSON-safe, ~33% overhead) +- **Max size**: 10MB enforced +- **Tip**: Break large function calls into batches + +### Resource Provisioning + +- **Parallel deployment**: Max 3 concurrent (configurable) +- **Retries**: 3 attempts with exponential backoff +- **Boot time**: Scales with number of resources (parallel helps) +- **Persistent state**: Enables fast boot recovery if partially failed + +--- + +## Contributing + +When modifying Flash Deploy, ensure: + +1. **Manifest schema backward compatibility**: Don't break existing manifest loading +2. **State file format**: Document any changes to `.runpod/resources.pkl` format +3. **Environment variables**: Document all new env vars in this guide +4. **Tests**: Add tests for new provisioning logic +5. 
**Documentation**: Update this guide for new features From 03e477ff9dc369a233491095eb9062d5e44d6cb8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dean=20Qui=C3=B1anola?= Date: Thu, 22 Jan 2026 00:40:12 -0800 Subject: [PATCH 02/11] refactor: migrate service registry to peer-to-peer State Manager - Replace ManifestClient (mothership) with StateManagerClient for peer-to-peer architecture - Make get_endpoint_for_function, get_resource_for_function, and is_local_function async - Remove manifest_client parameter from ServiceRegistry.__init__ - Query State Manager directly for full manifest with extracted resources_endpoints mapping - Update docstrings to reflect peer-to-peer model and State Manager dependency - Ensure manifest cache refresh happens before routing decisions --- src/tetra_rp/runtime/service_registry.py | 87 ++++++++++++++---------- 1 file changed, 50 insertions(+), 37 deletions(-) diff --git a/src/tetra_rp/runtime/service_registry.py b/src/tetra_rp/runtime/service_registry.py index 870283bd..79eb5ac2 100644 --- a/src/tetra_rp/runtime/service_registry.py +++ b/src/tetra_rp/runtime/service_registry.py @@ -12,7 +12,7 @@ from tetra_rp.core.resources.serverless import ServerlessResource from .config import DEFAULT_CACHE_TTL -from .manifest_client import ManifestClient, ManifestServiceUnavailableError +from .state_manager_client import StateManagerClient, ManifestServiceUnavailableError from .models import Manifest logger = logging.getLogger(__name__) @@ -29,26 +29,26 @@ class ServiceRegistry: def __init__( self, manifest_path: Optional[Path] = None, - manifest_client: Optional[ManifestClient] = None, cache_ttl: int = DEFAULT_CACHE_TTL, ): - """Initialize service registry. + """Initialize service registry with peer-to-peer State Manager access. + + All endpoints query State Manager directly for manifest updates. + No Mothership dependency - all endpoints are equal peers. Args: manifest_path: Path to flash_manifest.json. 
Defaults to FLASH_MANIFEST_PATH env var or auto-detection. - manifest_client: Manifest service client for mothership API. If None, creates one - from FLASH_MOTHERSHIP_ID env var. cache_ttl: Manifest cache lifetime in seconds (default: 300). - Environment Variables (for local vs remote detection): - FLASH_RESOURCE_NAME: Resource config name for this endpoint (child endpoints only). - Identifies which resource config this endpoint represents in the manifest. - RUNPOD_ENDPOINT_ID: Endpoint ID (used as fallback for mothership identification). + Environment Variables: + FLASH_RESOURCE_NAME: Resource config name for this endpoint. + Identifies which resource config this endpoint represents. + RUNPOD_ENDPOINT_ID: Endpoint ID (used for State Manager queries and fallback). + RUNPOD_API_KEY: API key for State Manager GraphQL access. Raises: FileNotFoundError: If manifest_path doesn't exist. - ValueError: If required env vars missing for manifest_client. """ self.cache_ttl = cache_ttl self._endpoint_registry: Dict[str, str] = {} @@ -65,24 +65,14 @@ def __init__( # Load manifest self._load_manifest(manifest_path) - # Initialize manifest client - if manifest_client is None: - mothership_id = os.getenv("FLASH_MOTHERSHIP_ID") - if mothership_id: - try: - manifest_client = ManifestClient() - except ValueError as e: - logger.warning(f"Failed to initialize manifest client: {e}") - manifest_client = None - else: - logger.debug( - "FLASH_MOTHERSHIP_ID not set, manifest service unavailable" - ) - manifest_client = None + # Peer-to-peer: All endpoints use StateManagerClient directly + try: + self._manifest_client = StateManagerClient() + except Exception as e: + logger.warning(f"Failed to initialize State Manager client: {e}") + self._manifest_client = None - self._manifest_client = manifest_client - # Child endpoints use FLASH_RESOURCE_NAME to identify which resource config they represent - # Mothership doesn't have FLASH_RESOURCE_NAME, so falls back to RUNPOD_ENDPOINT_ID + # Current 
endpoint identification for local vs remote detection self._current_endpoint = os.getenv("FLASH_RESOURCE_NAME") or os.getenv( "RUNPOD_ENDPOINT_ID" ) @@ -143,36 +133,56 @@ def _load_manifest(self, manifest_path: Optional[Path]) -> None: ) async def _ensure_manifest_loaded(self) -> None: - """Load manifest from mothership if cache expired or not loaded.""" + """Load manifest from State Manager if cache expired or not loaded. + + Uses peer-to-peer model - queries State Manager directly. + """ async with self._endpoint_registry_lock: now = time.time() cache_age = now - self._endpoint_registry_loaded_at if cache_age > self.cache_ttl: if self._manifest_client is None: - logger.debug("Manifest client not available, skipping refresh") + logger.debug("State Manager client not available, skipping refresh") return try: - self._endpoint_registry = await self._manifest_client.get_manifest() + mothership_id = os.getenv("RUNPOD_ENDPOINT_ID") + if not mothership_id: + logger.warning( + "RUNPOD_ENDPOINT_ID not set, cannot query State Manager" + ) + return + + # Query State Manager directly for full manifest + full_manifest = await self._manifest_client.get_persisted_manifest( + mothership_id + ) + + # Extract resources_endpoints mapping + resources_endpoints = full_manifest.get("resources_endpoints", {}) + + self._endpoint_registry = resources_endpoints self._endpoint_registry_loaded_at = now logger.debug( - f"Manifest loaded: {len(self._endpoint_registry)} endpoints, " + f"Manifest loaded from State Manager: {len(self._endpoint_registry)} endpoints, " f"cache TTL {self.cache_ttl}s" ) except ManifestServiceUnavailableError as e: logger.warning( - f"Failed to load manifest: {e}. " + f"Failed to load manifest from State Manager: {e}. " f"Cross-endpoint routing unavailable." 
) self._endpoint_registry = {} - def get_endpoint_for_function(self, function_name: str) -> Optional[str]: + async def get_endpoint_for_function(self, function_name: str) -> Optional[str]: """Get endpoint URL for a function. Determines if function is local (same endpoint) or remote (different endpoint), returning None for local and URL for remote. + Queries State Manager if endpoint registry cache is expired. + Args: function_name: Name of the function to route. @@ -182,6 +192,9 @@ def get_endpoint_for_function(self, function_name: str) -> Optional[str]: Raises: ValueError: If function not in manifest. """ + # Ensure manifest is loaded from State Manager (with caching) + await self._ensure_manifest_loaded() + function_registry = self._manifest.function_registry if function_name not in function_registry: @@ -206,7 +219,7 @@ def get_endpoint_for_function(self, function_name: str) -> Optional[str]: return endpoint_url - def get_resource_for_function( + async def get_resource_for_function( self, function_name: str ) -> Optional[ServerlessResource]: """Get ServerlessResource for a function. @@ -224,7 +237,7 @@ def get_resource_for_function( Raises: ValueError: If function not in manifest. """ - endpoint_url = self.get_endpoint_for_function(function_name) + endpoint_url = await self.get_endpoint_for_function(function_name) if endpoint_url is None: return None # Local function @@ -251,7 +264,7 @@ def get_resource_for_function( return resource - def is_local_function(self, function_name: str) -> bool: + async def is_local_function(self, function_name: str) -> bool: """Check if function executes on current endpoint. Args: @@ -261,7 +274,7 @@ def is_local_function(self, function_name: str) -> bool: True if function is local, False if remote or not found. 
""" try: - endpoint_url = self.get_endpoint_for_function(function_name) + endpoint_url = await self.get_endpoint_for_function(function_name) return endpoint_url is None except ValueError: # Function not in manifest, assume local (will execute and fail) From ceeac75a01e6b21d69b733a9fcf158942153dc35 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dean=20Qui=C3=B1anola?= Date: Thu, 22 Jan 2026 01:09:37 -0800 Subject: [PATCH 03/11] refactor: remove obsolete hub-and-spoke infrastructure from tetra-rp Remove ManifestClient HTTP client and /manifest endpoint as they're replaced by peer-to-peer StateManagerClient model where all endpoints query State Manager directly. This eliminates single point of failure and simplifies architecture. Changes: - Delete ManifestClient (src/tetra_rp/runtime/manifest_client.py) - Delete ManifestClient tests (tests/unit/runtime/test_manifest_client.py) - Remove /manifest endpoint from lb_handler.py - Remove /manifest endpoint tests from test_lb_handler.py and test_lb_remote_execution.py - Update integration tests to use StateManagerClient mocks - Remove FLASH_MOTHERSHIP_ID environment variable references - Update documentation to reflect peer-to-peer architecture - Update CLI test-mothership command output Peer-to-peer architecture benefits: - No single point of failure (no mothership dependency) - All endpoints are equal peers - Simpler deployment model - Consistent service discovery via State Manager GraphQL API --- docs/Cross_Endpoint_Routing.md | 225 +++-------- src/tetra_rp/cli/commands/test_mothership.py | 11 +- src/tetra_rp/runtime/lb_handler.py | 50 --- src/tetra_rp/runtime/manifest_client.py | 141 ------- .../test_cross_endpoint_routing.py | 39 +- tests/integration/test_lb_remote_execution.py | 218 ----------- tests/unit/runtime/test_lb_handler.py | 370 +----------------- tests/unit/runtime/test_manifest_client.py | 207 ---------- 8 files changed, 86 insertions(+), 1175 deletions(-) delete mode 100644 src/tetra_rp/runtime/manifest_client.py 
delete mode 100644 tests/unit/runtime/test_manifest_client.py diff --git a/docs/Cross_Endpoint_Routing.md b/docs/Cross_Endpoint_Routing.md index 3f864780..aa83c194 100644 --- a/docs/Cross_Endpoint_Routing.md +++ b/docs/Cross_Endpoint_Routing.md @@ -59,20 +59,20 @@ The manifest structure: #### 2. Set Environment Variables -Configure the mothership manifest URL (required for remote routing): +Configure State Manager access for peer-to-peer routing: ```bash -# Required for cross-endpoint routing to work -export FLASH_MOTHERSHIP_ID=mothership-endpoint-id +# Required: API key for State Manager GraphQL access +export RUNPOD_API_KEY=your-api-key -# For child endpoints: Identifies which resource config this endpoint represents +# Optional: Identifies which resource config this endpoint represents export FLASH_RESOURCE_NAME=gpu_config -# Fallback: Used if FLASH_RESOURCE_NAME not set (for mothership identification) +# Optional: Fallback endpoint ID export RUNPOD_ENDPOINT_ID=gpu-endpoint-123 ``` -Note: Without `FLASH_MOTHERSHIP_ID`, all functions execute locally. The system gracefully falls back to local execution. +**Architecture**: Cross-endpoint routing uses a peer-to-peer model where all endpoints query State Manager directly for service discovery. No single point of failure (no hub-and-spoke). #### 3. 
Define Functions @@ -152,13 +152,11 @@ The manifest file (`flash_manifest.json`) defines function routing and resource | Variable | Required | Purpose | |----------|----------|---------| -| `FLASH_MOTHERSHIP_ID` | Yes* | Mothership endpoint ID for manifest service | -| `FLASH_RESOURCE_NAME` | No | Resource config name for child endpoints (identifies which resource this endpoint represents) | -| `RUNPOD_ENDPOINT_ID` | No | Fallback endpoint ID (used if FLASH_RESOURCE_NAME not set) | +| `RUNPOD_API_KEY` | Yes | State Manager GraphQL authentication | +| `FLASH_RESOURCE_NAME` | No | Resource config name for this endpoint | +| `RUNPOD_ENDPOINT_ID` | No | Fallback endpoint ID | | `FLASH_MANIFEST_PATH` | No | Explicit path to manifest file | -*Required for remote routing; without it, all functions execute locally - ### Usage Patterns #### Pattern 1: Microservice Architecture @@ -275,12 +273,12 @@ async def helper_function(x: int) -> int: **Manifest Service Unavailable** -If `FLASH_MOTHERSHIP_ID` is not set or unreachable: +If `RUNPOD_API_KEY` is not set or State Manager is unreachable: ``` -WARNING: FLASH_MOTHERSHIP_ID not set, manifest service unavailable +WARNING: State Manager unavailable, manifest service not available ``` -Functions default to local execution. Set the environment variable to enable routing. +Functions default to local execution. Set `RUNPOD_API_KEY` to enable remote routing. **Manifest Not Found** @@ -469,16 +467,17 @@ class ServiceRegistry: cache_ttl: Manifest cache lifetime in seconds (default: 300). Environment Variables (for local vs remote detection): + RUNPOD_API_KEY: API key for State Manager GraphQL access. FLASH_RESOURCE_NAME: Resource config name for this endpoint (child endpoints). Identifies which resource config this endpoint represents in the manifest. - RUNPOD_ENDPOINT_ID: Endpoint ID (used as fallback for mothership identification). + RUNPOD_ENDPOINT_ID: Endpoint ID (used as fallback for identification). 
""" self._load_manifest(manifest_path) - self._manifest_client = manifest_client or ManifestClient() + self._manifest_client = manifest_client or StateManagerClient() self._endpoint_registry = {} # Cached endpoint URLs self._endpoint_registry_lock = asyncio.Lock() # Child endpoints use FLASH_RESOURCE_NAME to identify which resource they represent - # Mothership doesn't have FLASH_RESOURCE_NAME, so falls back to RUNPOD_ENDPOINT_ID + # Falls back to RUNPOD_ENDPOINT_ID if not set self._current_endpoint = os.getenv("FLASH_RESOURCE_NAME") or os.getenv( "RUNPOD_ENDPOINT_ID" ) @@ -492,13 +491,14 @@ class ServiceRegistry: return self._resolve_resource(config) async def _ensure_manifest_loaded(self) -> None: - """Load manifest from mothership if cache expired or not loaded.""" + """Load manifest from State Manager if cache expired or not loaded.""" async with self._endpoint_registry_lock: now = time.time() cache_age = now - self._endpoint_registry_loaded_at if cache_age > self.cache_ttl: - self._endpoint_registry = await self._manifest_client.get_manifest() + manifest = await self._manifest_client.get_persisted_manifest(mothership_id) + self._endpoint_registry = manifest.get("resources_endpoints", {}) self._endpoint_registry_loaded_at = now ``` @@ -530,41 +530,7 @@ class ServiceRegistry: - Thread-safe with `asyncio.Lock()` - Graceful fallback if manifest service unavailable -#### 3. ManifestClient - -**Location**: `src/tetra_rp/runtime/manifest_client.py` - -HTTP client for mothership manifest service: - -```python -class ManifestClient: - """HTTP client for querying mothership manifest. - - The manifest maps resource_config names to their endpoint URLs. - Example: {"gpu_config": "https://api.runpod.io/v2/abc123"} - """ - - async def get_manifest(self) -> Dict[str, str]: - """Fetch endpoint manifest from mothership. - - Returns: - Dictionary mapping resource_config_name → endpoint_url. 
- Example: {"gpu_config": "https://api.runpod.io/v2/abc123"} - - Raises: - ManifestServiceUnavailableError: If manifest service unavailable after retries. - """ - # Queries {mothership_url}/manifest endpoint with retry logic -``` - -**Configuration**: -- Mothership ID from `FLASH_MOTHERSHIP_ID` env var (constructs URL as https://{id}.api.runpod.ai) -- HTTP timeout: 10 seconds (via `DEFAULT_REQUEST_TIMEOUT`) -- Retry logic: Exponential backoff with `DEFAULT_MAX_RETRIES` attempts (default: 3) -- Uses `httpx` library for async HTTP requests -- Raises `ImportError` if httpx not installed (with helpful message) - -#### 4. StateManagerClient +#### 3. StateManagerClient **Location**: `src/tetra_rp/runtime/state_manager_client.py` @@ -1010,144 +976,69 @@ import os print(f"FLASH_MOTHERSHIP_ID: {os.getenv('FLASH_MOTHERSHIP_ID')}") print(f"RUNPOD_ENDPOINT_ID: {os.getenv('RUNPOD_ENDPOINT_ID')}") -# Check manifest client directly -client = ManifestClient(mothership_url=...) -endpoints = await client.get_manifest() +# Check state manager client directly +client = StateManagerClient() +manifest = await client.get_persisted_manifest(mothership_id) ``` -## Manifest Synchronization with RunPod GraphQL API +## Peer-to-Peer Architecture with StateManagerClient ### Overview -The Mothership's GET /manifest endpoint pulls configuration from RunPod's GraphQL API, -which serves as the single source of truth for manifest data. This enables centralized -configuration management and ensures all child endpoints receive consistent routing -information. +Cross-endpoint routing uses a **peer-to-peer architecture** where all endpoints query State Manager directly for service discovery. This eliminates single points of failure and simplifies the system architecture compared to previous hub-and-spoke models. + +**Key Difference**: No mothership endpoint exposing a `/manifest` HTTP endpoint. Instead, all endpoints use `StateManagerClient` to query the RunPod GraphQL API directly. 
### Architecture ```mermaid flowchart TD - A["Child Endpoint
GET /manifest"] - B["Mothership"] - C["ManifestFetcher"] - D{Cache Valid?} - E["Serve Cached
Manifest"] - F["Fetch from RunPod
GraphQL API"] - G["Update
flash_manifest.json"] - H["Cache Result
TTL: 300s"] - I["Serve Manifest"] - J["Fallback:
Load Local File"] - - A -->|Request| B - B --> C - C --> D - D -->|Yes| E - D -->|No| F - E --> I - F --> G - G --> H - H --> I - F -->|Fails| J - J --> I + A["Endpoint A"] + B["Endpoint B"] + C["Endpoint C"] + D["State Manager
GraphQL API"] + E["RunPod API Key"] + + A -->|Query Manifest| D + B -->|Query Manifest| D + C -->|Query Manifest| D + D -->|Requires| E style A fill:#1976d2,stroke:#0d47a1,stroke-width:3px,color:#fff - style B fill:#388e3c,stroke:#1b5e20,stroke-width:3px,color:#fff - style C fill:#388e3c,stroke:#1b5e20,stroke-width:3px,color:#fff - style D fill:#f57c00,stroke:#e65100,stroke-width:3px,color:#fff - style E fill:#388e3c,stroke:#1b5e20,stroke-width:3px,color:#fff - style F fill:#d32f2f,stroke:#b71c1c,stroke-width:3px,color:#fff - style G fill:#d32f2f,stroke:#b71c1c,stroke-width:3px,color:#fff - style H fill:#388e3c,stroke:#1b5e20,stroke-width:3px,color:#fff - style I fill:#1976d2,stroke:#0d47a1,stroke-width:3px,color:#fff - style J fill:#d32f2f,stroke:#b71c1c,stroke-width:3px,color:#fff + style B fill:#1976d2,stroke:#0d47a1,stroke-width:3px,color:#fff + style C fill:#1976d2,stroke:#0d47a1,stroke-width:3px,color:#fff + style D fill:#388e3c,stroke:#1b5e20,stroke-width:3px,color:#fff + style E fill:#f57c00,stroke:#e65100,stroke-width:3px,color:#fff ``` ### How It Works -1. **Source of Truth**: RunPod GraphQL API holds the authoritative manifest configuration -2. **Caching Proxy**: Mothership fetches from RunPod GQL, caches locally (5 min TTL) -3. **Local Persistence**: Fetched manifest written to `flash_manifest.json` -4. **Graceful Fallback**: If RunPod GQL unavailable, serves local file -5. 
**Cache Invalidation**: Automatic expiry after TTL, manual invalidation supported - -### Implementation Status - -**Current (Placeholder)**: -- `ManifestFetcher` class with caching infrastructure -- Uses existing `RunpodGraphQLClient` for API communication -- Falls back to local `flash_manifest.json` (GQL fetch raises `NotImplementedError`) -- Cache TTL: 300 seconds (configurable) - -**Future (Full Implementation)**: -- Implement `getManifest` query in `ManifestFetcher._fetch_from_gql()` -- Add `saveManifest` mutation for updating manifest in RunPod -- Real-time cache invalidation via webhooks -- Health checks and retry logic +1. **All Endpoints Equal**: Each endpoint is a peer, no master/slave relationship +2. **Direct Query**: All endpoints query State Manager GraphQL API directly +3. **No HTTP Endpoint**: No `/manifest` endpoint needed - endpoints communicate via GraphQL +4. **Caching**: ServiceRegistry caches results locally (300s TTL) to reduce State Manager load +5. **Graceful Fallback**: If State Manager unavailable, functions default to local execution ### Configuration ```bash -# Enable Mothership mode (required for /manifest endpoint) -export FLASH_IS_MOTHERSHIP=true - -# Optional: Identify this mothership instance -export RUNPOD_ENDPOINT_ID=mothership-prod-1 - -# Required for RunPod GraphQL API access +# Required: API key for State Manager GraphQL access export RUNPOD_API_KEY=your-api-key-here -``` - -### Cache Behavior -- **Default TTL**: 300 seconds (5 minutes) -- **Cache Key**: Per-mothership instance (no cross-instance cache) -- **Thread-Safe**: Uses `asyncio.Lock` for concurrent request handling -- **Manual Invalidation**: `fetcher.invalidate_cache()` for testing - -### Historical Context - -A previous `StateManagerClient` (commit b19bf7c) used REST API; the current -implementation now reads and updates manifests through RunPod GraphQL mutations. - -### Migration Guide: REST to GraphQL (PR #144) - -#### Breaking Changes - -1. 
**get_flash_build() signature changed**: - ```python - # Before - await client.get_flash_build({"flashBuildId": build_id}) - - # After - await client.get_flash_build(build_id) - ``` - -2. **StateManagerClient no longer uses httpx**: - - Remove `httpx` from dependencies if only used by StateManagerClient - - All operations now use RunpodGraphQLClient - -3. **Constructor parameters deprecated**: - - `base_url`: Ignored (GraphQL client manages URLs) - - `timeout`: Ignored (GraphQL client manages timeouts) - - Code continues to work but logs deprecation warnings - -#### Performance Considerations - -The new GraphQL implementation requires 3 sequential API calls per update: -1. Fetch environment to get activeBuildId -2. Fetch build to get current manifest -3. Update manifest with changes +# Optional: Identifies which resource config this endpoint represents +export FLASH_RESOURCE_NAME=gpu_config -For bulk resource updates (10+ resources), consider batching operations -to reduce latency. +# Optional: Fallback endpoint ID +export RUNPOD_ENDPOINT_ID=gpu-endpoint-123 +``` -#### Concurrency Safety +### StateManagerClient Features -StateManagerClient now uses `asyncio.Lock` to prevent race conditions -during concurrent resource updates. This ensures manifest integrity -during mothership auto-provisioning when multiple resources deploy -simultaneously. 
+- **GraphQL Query**: Queries RunPod GraphQL API for manifest persistence +- **Caching**: 300-second TTL cache to minimize API calls +- **Retry Logic**: Exponential backoff on failures (default 3 attempts) +- **Thread-Safe**: Uses `asyncio.Lock` for concurrent operations +- **Auto-Provisioning**: Used by mothership provisioner to update resource state ## Key Implementation Highlights diff --git a/src/tetra_rp/cli/commands/test_mothership.py b/src/tetra_rp/cli/commands/test_mothership.py index ea239e78..41e1e663 100644 --- a/src/tetra_rp/cli/commands/test_mothership.py +++ b/src/tetra_rp/cli/commands/test_mothership.py @@ -345,12 +345,13 @@ def _display_test_objectives() -> None: """Display what test-mothership tests and important warnings.""" objectives_text = """[bold cyan]What this tests:[/bold cyan] • Mothership container deployment -• Child endpoint auto-provisioning -• Manifest generation and updates +• Child endpoint auto-provisioning via State Manager +• Manifest persistence and State Manager integration [bold yellow]⚠ Important:[/bold yellow] +• Uses peer-to-peer architecture (no hub-and-spoke) +• All endpoints query State Manager directly • Child endpoints are [bold]temporary[/bold] - prefixed with 'tmp-' -• Child endpoints are [bold]non-functional[/bold] for actual workloads • All child endpoints will be [bold]automatically cleaned up[/bold] on shutdown [dim]These are test deployments only. 
Use 'flash deploy' for production.[/dim]""" @@ -429,9 +430,9 @@ def _run_docker_container(docker_cmd: list, port: int) -> None: console.print("[dim]Verification commands:[/dim]") console.print(f"[dim] Health: curl http://localhost:{port}/ping[/dim]") console.print( - f"[dim] Manifest (child endpoints see this):" - f" curl http://localhost:{port}/manifest[/dim]\n" + "[dim] State Manager Query: All endpoints query State Manager directly[/dim]" ) + console.print("[dim] No /manifest endpoint - peer-to-peer architecture[/dim]\n") console.print("[bold]Test phases:[/bold]") console.print(" [dim]1. Mothership startup and health check[/dim]") console.print( diff --git a/src/tetra_rp/runtime/lb_handler.py b/src/tetra_rp/runtime/lb_handler.py index ae7d5f73..7a317cd5 100644 --- a/src/tetra_rp/runtime/lb_handler.py +++ b/src/tetra_rp/runtime/lb_handler.py @@ -4,7 +4,6 @@ that handle load-balanced serverless endpoints. It supports: - User-defined HTTP routes - /execute endpoint for @remote function execution (LiveLoadBalancer only) -- /manifest endpoint for mothership service discovery (when FLASH_IS_MOTHERSHIP=true) Security Model: The /execute endpoint accepts and executes serialized function code. This is @@ -15,21 +14,14 @@ 4. In production, API authentication should protect the /execute endpoint Users should NOT expose the /execute endpoint to untrusted clients. - - The /manifest endpoint returns deployment metadata and is safe to expose - publicly as it contains only structural information about deployed functions. 
""" import inspect import logging -import os -from functools import lru_cache from typing import Any, Callable, Dict from fastapi import FastAPI, Request -from fastapi.responses import JSONResponse -from .manifest_fetcher import ManifestFetcher from .serialization import ( deserialize_args, deserialize_kwargs, @@ -39,15 +31,6 @@ logger = logging.getLogger(__name__) -@lru_cache(maxsize=1) -def _get_manifest_fetcher() -> ManifestFetcher: - """Get or create the manifest fetcher singleton. - - Uses @lru_cache for thread-safe lazy initialization. - """ - return ManifestFetcher() - - def create_lb_handler( route_registry: Dict[tuple[str, str], Callable], include_execute: bool = False, @@ -184,39 +167,6 @@ async def execute_remote_function(request: Request) -> Dict[str, Any]: logger.error(f"Unexpected error in /execute endpoint: {e}") return {"success": False, "error": f"Unexpected error: {e}"} - # Register /manifest endpoint for mothership discovery (if enabled) - if os.getenv("FLASH_IS_MOTHERSHIP", "").lower() == "true": - - @app.get("/manifest") - async def get_manifest() -> JSONResponse: - """Mothership discovery endpoint. - - Fetches manifest from RunPod GraphQL API (source of truth), caches it - locally, and serves to child endpoints. Falls back to local file if - RunPod API is unavailable. - - Only available when FLASH_IS_MOTHERSHIP=true environment variable is set. 
- - Returns: - JSONResponse with manifest content or 404 if not found - """ - fetcher = _get_manifest_fetcher() - mothership_id = os.getenv("RUNPOD_ENDPOINT_ID") - - # Fetch manifest (from cache, RunPod GQL, or local file) - manifest_dict = await fetcher.get_manifest(mothership_id) - - if not manifest_dict or not manifest_dict.get("resources"): - return JSONResponse( - status_code=404, - content={ - "error": "Manifest not found", - "detail": "Could not load manifest from RunPod or local file", - }, - ) - - return JSONResponse(status_code=200, content=manifest_dict) - # Register user-defined routes from registry for (method, path), handler in route_registry.items(): method_upper = method.upper() diff --git a/src/tetra_rp/runtime/manifest_client.py b/src/tetra_rp/runtime/manifest_client.py deleted file mode 100644 index 68ba2184..00000000 --- a/src/tetra_rp/runtime/manifest_client.py +++ /dev/null @@ -1,141 +0,0 @@ -"""HTTP client for mothership manifest API.""" - -import asyncio -import logging -import os -from typing import Dict, Optional - -try: - import httpx -except ImportError: - httpx = None - -from .config import DEFAULT_MAX_RETRIES, DEFAULT_REQUEST_TIMEOUT -from .exceptions import ManifestServiceUnavailableError - -logger = logging.getLogger(__name__) - - -class ManifestClient: - """HTTP client for querying mothership manifest service. - - Fetches the manifest (endpoint registry) that maps resource_config names to - their deployment URLs. The manifest provides service discovery for remote - resource endpoints. - - Example: {"gpu_config": "https://api.runpod.io/v2/abc123"} - """ - - def __init__( - self, - mothership_url: Optional[str] = None, - timeout: int = DEFAULT_REQUEST_TIMEOUT, - max_retries: int = DEFAULT_MAX_RETRIES, - ): - """Initialize manifest client. - - Args: - mothership_url: Base URL of mothership endpoint. Defaults to - constructed from FLASH_MOTHERSHIP_ID environment variable. - timeout: Request timeout in seconds (default: 10). 
- max_retries: Maximum retry attempts (default: 3). - - Raises: - ValueError: If mothership_url not provided and env var not set. - """ - if mothership_url: - self.mothership_url = mothership_url - else: - mothership_id = os.getenv("FLASH_MOTHERSHIP_ID") - if not mothership_id: - raise ValueError( - "mothership_url required: pass mothership_url or set " - "FLASH_MOTHERSHIP_ID environment variable" - ) - self.mothership_url = f"https://{mothership_id}.api.runpod.ai" - - self.timeout = timeout - self.max_retries = max_retries - self._client: Optional[httpx.AsyncClient] = None - - async def get_manifest(self) -> Dict[str, str]: - """Fetch endpoint manifest from mothership. - - Returns: - Dictionary mapping resource_config_name → endpoint_url. - Example: {"gpu_config": "https://gpu-worker.api.runpod.ai"} - - Raises: - ManifestServiceUnavailableError: If manifest service unavailable after retries. - """ - if httpx is None: - raise ImportError( - "httpx required for ManifestClient. Install with: pip install httpx" - ) - - last_exception: Optional[Exception] = None - - for attempt in range(self.max_retries): - try: - client = await self._get_client() - response = await client.get( - f"{self.mothership_url}/manifest", - timeout=self.timeout, - ) - - if response.status_code >= 400: - raise ManifestServiceUnavailableError( - f"Manifest API returned {response.status_code}: " - f"{response.text[:200]}" - ) - - manifest = response.json() - if not isinstance(manifest, dict) or "resources" not in manifest: - raise ManifestServiceUnavailableError( - "Invalid manifest response: missing 'resources'" - ) - - logger.debug( - f"Manifest loaded: {len(manifest.get('resources', {}))} resources" - ) - return manifest - - except ( - asyncio.TimeoutError, - ManifestServiceUnavailableError, - Exception, - ) as e: - last_exception = e - if attempt < self.max_retries - 1: - backoff = 2**attempt - logger.warning( - f"Manifest service request failed (attempt {attempt + 1}): {e}, " - f"retrying in 
{backoff}s..." - ) - await asyncio.sleep(backoff) - continue - - raise ManifestServiceUnavailableError( - f"Failed to fetch manifest after {self.max_retries} attempts: {last_exception}" - ) - - async def _get_client(self) -> httpx.AsyncClient: - """Get or create HTTP client with proper configuration.""" - if self._client is None or self._client.is_closed: - timeout = httpx.Timeout(self.timeout) - self._client = httpx.AsyncClient(timeout=timeout) - - return self._client - - async def close(self) -> None: - """Close HTTP session.""" - if self._client and not self._client.is_closed: - await self._client.aclose() - - async def __aenter__(self): - """Async context manager entry.""" - return self - - async def __aexit__(self, exc_type, exc_val, exc_tb): - """Async context manager exit.""" - await self.close() diff --git a/tests/integration/test_cross_endpoint_routing.py b/tests/integration/test_cross_endpoint_routing.py index c17e56ab..3ce052c4 100644 --- a/tests/integration/test_cross_endpoint_routing.py +++ b/tests/integration/test_cross_endpoint_routing.py @@ -7,13 +7,13 @@ import pytest -from tetra_rp.runtime.manifest_client import ManifestClient from tetra_rp.runtime.production_wrapper import ( ProductionWrapper, create_production_wrapper, reset_wrapper, ) from tetra_rp.runtime.service_registry import ServiceRegistry +from tetra_rp.runtime.state_manager_client import StateManagerClient class TestCrossEndpointRoutingIntegration: @@ -71,7 +71,6 @@ async def test_local_function_execution(self, manifest): "os.environ", { "FLASH_RESOURCE_NAME": "gpu_config", - "FLASH_MOTHERSHIP_ID": "mothership-test", }, ): endpoint_registry = { @@ -88,10 +87,12 @@ async def test_local_function_execution(self, manifest): try: registry = ServiceRegistry(manifest_path=manifest_path) - mock_manifest_client = AsyncMock(spec=ManifestClient) - mock_manifest_client.get_manifest.return_value = endpoint_registry + mock_state_manager = AsyncMock(spec=StateManagerClient) + 
mock_state_manager.get_persisted_manifest.return_value = { + "resources_endpoints": endpoint_registry + } - registry._manifest_client = mock_manifest_client + registry._manifest_client = mock_state_manager registry._endpoint_registry = endpoint_registry registry._endpoint_registry_loaded_at = float("inf") @@ -125,7 +126,6 @@ async def test_remote_function_execution_routing(self, manifest): "os.environ", { "FLASH_RESOURCE_NAME": "gpu_config", - "FLASH_MOTHERSHIP_ID": "mothership-test", }, ): endpoint_registry = { @@ -141,9 +141,11 @@ async def test_remote_function_execution_routing(self, manifest): try: registry = ServiceRegistry(manifest_path=manifest_path) - mock_manifest_client = AsyncMock(spec=ManifestClient) - mock_manifest_client.get_manifest.return_value = endpoint_registry - registry._manifest_client = mock_manifest_client + mock_state_manager = AsyncMock(spec=StateManagerClient) + mock_state_manager.get_persisted_manifest.return_value = { + "resources_endpoints": endpoint_registry + } + registry._manifest_client = mock_state_manager registry._endpoint_registry = endpoint_registry registry._endpoint_registry_loaded_at = float("inf") @@ -189,7 +191,6 @@ async def test_manifest_loading_on_demand(self, manifest): "os.environ", { "FLASH_RESOURCE_NAME": "gpu_config", - "FLASH_MOTHERSHIP_ID": "mothership-test", }, ): endpoint_registry = { @@ -205,9 +206,11 @@ async def test_manifest_loading_on_demand(self, manifest): try: registry = ServiceRegistry(manifest_path=manifest_path) - mock_manifest_client = AsyncMock(spec=ManifestClient) - mock_manifest_client.get_manifest.return_value = endpoint_registry - registry._manifest_client = mock_manifest_client + mock_state_manager = AsyncMock(spec=StateManagerClient) + mock_state_manager.get_persisted_manifest.return_value = { + "resources_endpoints": endpoint_registry + } + registry._manifest_client = mock_state_manager assert registry._endpoint_registry == {} @@ -246,7 +249,6 @@ async def 
test_error_handling_in_remote_execution(self, manifest): "os.environ", { "FLASH_RESOURCE_NAME": "gpu_config", - "FLASH_MOTHERSHIP_ID": "mothership-test", }, ): endpoint_registry = { @@ -262,9 +264,11 @@ async def test_error_handling_in_remote_execution(self, manifest): try: registry = ServiceRegistry(manifest_path=manifest_path) - mock_manifest_client = AsyncMock(spec=ManifestClient) - mock_manifest_client.get_manifest.return_value = endpoint_registry - registry._manifest_client = mock_manifest_client + mock_state_manager = AsyncMock(spec=StateManagerClient) + mock_state_manager.get_persisted_manifest.return_value = { + "resources_endpoints": endpoint_registry + } + registry._manifest_client = mock_state_manager registry._endpoint_registry = endpoint_registry registry._endpoint_registry_loaded_at = float("inf") @@ -316,7 +320,6 @@ def test_factory_creates_complete_system(self): "os.environ", { "RUNPOD_ENDPOINT_ID": "resource1", - "FLASH_MOTHERSHIP_ID": "mothership-test", }, ): wrapper = create_production_wrapper() diff --git a/tests/integration/test_lb_remote_execution.py b/tests/integration/test_lb_remote_execution.py index b428d6c0..1c51fd89 100644 --- a/tests/integration/test_lb_remote_execution.py +++ b/tests/integration/test_lb_remote_execution.py @@ -369,221 +369,3 @@ def test_handler_generation_with_numeric_module_paths(self): # Verify no invalid from syntax assert "from 03_advanced_workers" not in handler_code - - -class TestManifestEndpointIntegration: - """Integration tests for GET /manifest endpoint.""" - - @pytest.fixture(autouse=True) - def reset_manifest_fetcher(self): - """Reset the global manifest fetcher before each test.""" - import tetra_rp.runtime.lb_handler as lb_handler_module - - lb_handler_module._manifest_fetcher = None - yield - lb_handler_module._manifest_fetcher = None - - def test_manifest_endpoint_in_live_load_balancer(self, monkeypatch): - """Test manifest endpoint in LiveLoadBalancer with FLASH_IS_MOTHERSHIP=true.""" - from 
unittest.mock import patch, AsyncMock - from fastapi.testclient import TestClient - - monkeypatch.setenv("FLASH_IS_MOTHERSHIP", "true") - - # Create a LiveLoadBalancer - lb = LiveLoadBalancer(name="test-mothership") - - # Define a simple function on the mothership - @remote(lb, method="GET", path="/api/hello") - async def hello(): - return {"message": "hello"} - - # Create manifest data - test_manifest = { - "version": "1.0", - "generated_at": "2024-01-15T10:30:00Z", - "project_name": "test-app", - "resources": { - "test-mothership": { - "resource_type": "LiveLoadBalancer", - "handler_file": "handler_test_mothership.py", - "functions": [ - { - "name": "hello", - "module": "test_module", - "is_async": True, - "is_class": False, - "http_method": "GET", - "http_path": "/api/hello", - } - ], - } - }, - "function_registry": {"hello": "test-mothership"}, - "routes": {"test-mothership": {"GET /api/hello": "hello"}}, - } - - # Mock ManifestFetcher to return test manifest - with patch("tetra_rp.runtime.lb_handler.ManifestFetcher") as MockFetcher: - mock_fetcher = AsyncMock() - mock_fetcher.get_manifest = AsyncMock(return_value=test_manifest) - MockFetcher.return_value = mock_fetcher - - from tetra_rp.runtime.lb_handler import create_lb_handler - - # Create handler with manifest endpoint enabled - route_registry = {("GET", "/api/hello"): hello} - app = create_lb_handler(route_registry, include_execute=True) - client = TestClient(app) - - # Verify /manifest endpoint returns manifest - response = client.get("/manifest") - assert response.status_code == 200 - assert response.json() == test_manifest - - def test_manifest_endpoint_excluded_when_env_not_set(self): - """Test manifest endpoint is not available when FLASH_IS_MOTHERSHIP not set.""" - from fastapi.testclient import TestClient - from tetra_rp.runtime.lb_handler import create_lb_handler - - # Create handler without env var set - app = create_lb_handler({}, include_execute=False) - client = TestClient(app) - - # Verify 
/manifest returns 404 - response = client.get("/manifest") - assert response.status_code == 404 - - def test_manifest_endpoint_with_deployed_lb_resource(self, monkeypatch): - """Test manifest endpoint with LoadBalancerSlsResource.""" - from unittest.mock import patch, AsyncMock - from fastapi.testclient import TestClient - from tetra_rp.runtime.lb_handler import _get_manifest_fetcher - - monkeypatch.setenv("FLASH_IS_MOTHERSHIP", "true") - _get_manifest_fetcher.cache_clear() - - # Create test manifest for deployed endpoint - test_manifest = { - "version": "1.0", - "generated_at": "2024-01-15T10:30:00Z", - "project_name": "deployed-app", - "resources": { - "gpu-worker": { - "resource_type": "LoadBalancerSlsResource", - "handler_file": "handler_gpu_worker.py", - "functions": [ - { - "name": "process_image", - "module": "workers.gpu", - "is_async": True, - "is_class": False, - "http_method": "POST", - "http_path": "/api/process", - } - ], - } - }, - "function_registry": {"process_image": "gpu-worker"}, - } - - with patch("tetra_rp.runtime.lb_handler.ManifestFetcher") as MockFetcher: - mock_fetcher = AsyncMock() - mock_fetcher.get_manifest = AsyncMock(return_value=test_manifest) - MockFetcher.return_value = mock_fetcher - - from tetra_rp.runtime.lb_handler import create_lb_handler - - # Create deployed handler (not LiveLoadBalancer) - app = create_lb_handler({}, include_execute=False) - client = TestClient(app) - - # Verify /manifest endpoint is available - response = client.get("/manifest") - assert response.status_code == 200 - assert response.json() == test_manifest - - _get_manifest_fetcher.cache_clear() - - def test_manifest_endpoint_coexists_with_ping(self, monkeypatch): - """Test that /manifest endpoint coexists with /ping health check.""" - from unittest.mock import patch, AsyncMock - from fastapi.testclient import TestClient - from tetra_rp.runtime.lb_handler import _get_manifest_fetcher - - monkeypatch.setenv("FLASH_IS_MOTHERSHIP", "true") - 
_get_manifest_fetcher.cache_clear() - - test_manifest = { - "version": "1.0", - "resources": {"test": {}}, - "function_registry": {}, - } - - with patch("tetra_rp.runtime.lb_handler.ManifestFetcher") as MockFetcher: - mock_fetcher = AsyncMock() - mock_fetcher.get_manifest = AsyncMock(return_value=test_manifest) - MockFetcher.return_value = mock_fetcher - - from tetra_rp.runtime.lb_handler import create_lb_handler - - app = create_lb_handler({}, include_execute=False) - client = TestClient(app) - - # Verify both endpoints exist - manifest_response = client.get("/manifest") - assert manifest_response.status_code == 200 - - ping_response = client.get("/ping") - assert ping_response.status_code == 404 # Ping not auto-added by factory - - _get_manifest_fetcher.cache_clear() - - -class TestManifestClientToEndpointIntegration: - """Integration tests for ManifestClient calling GET /manifest endpoint.""" - - def test_manifest_client_can_parse_response(self): - """Test ManifestClient can parse manifest response directly.""" - import asyncio - from unittest.mock import patch, AsyncMock, MagicMock - from tetra_rp.runtime.manifest_client import ManifestClient - - # Create a manifest to simulate - test_manifest = { - "version": "1.0", - "generated_at": "2024-01-15T10:30:00Z", - "project_name": "test-app", - "resources": { - "gpu_config": { - "resource_type": "LoadBalancerSlsResource", - "handler_file": "handler_gpu.py", - "endpoint_url": "https://api.runpod.io/v2/gpu123", - } - }, - "function_registry": {"process_gpu": "gpu_config"}, - } - - async def test_client_parsing(): - # Create a mock httpx client that returns the manifest directly - mock_http_client = AsyncMock() - mock_response = MagicMock() - mock_response.status_code = 200 - mock_response.json.return_value = test_manifest - mock_http_client.get = AsyncMock(return_value=mock_response) - - # Create ManifestClient - client = ManifestClient(mothership_url="http://localhost:8000") - - # Mock the _get_client to return our 
mock - with patch.object(client, "_get_client", return_value=mock_http_client): - # Call get_manifest - should parse the response - result = await client.get_manifest() - - # Verify it successfully parsed the manifest - assert result == test_manifest - assert "gpu_config" in result["resources"] - assert result["function_registry"]["process_gpu"] == "gpu_config" - - # Run the async test - asyncio.run(test_client_parsing()) diff --git a/tests/unit/runtime/test_lb_handler.py b/tests/unit/runtime/test_lb_handler.py index 1da78a4f..55388b21 100644 --- a/tests/unit/runtime/test_lb_handler.py +++ b/tests/unit/runtime/test_lb_handler.py @@ -1,374 +1,6 @@ """Unit tests for LoadBalancer handler factory.""" -from unittest.mock import patch - -import pytest -from fastapi.testclient import TestClient - -from tetra_rp.runtime.lb_handler import create_lb_handler, _get_manifest_fetcher - - -class TestManifestEndpoint: - """Tests for GET /manifest endpoint.""" - - @pytest.fixture(autouse=True) - def reset_manifest_fetcher(self): - """Reset the manifest fetcher cache before each test.""" - _get_manifest_fetcher.cache_clear() - yield - _get_manifest_fetcher.cache_clear() - - @pytest.fixture - def sample_manifest(self): - """Sample manifest for testing.""" - return { - "version": "1.0", - "generated_at": "2024-01-15T10:30:00Z", - "project_name": "test-app", - "resources": { - "gpu_config": { - "resource_type": "LoadBalancerSlsResource", - "handler_file": "handler_gpu_config.py", - "functions": [ - { - "name": "process_image", - "module": "workers.gpu", - "is_async": True, - "is_class": False, - "http_method": "POST", - "http_path": "/api/process", - } - ], - } - }, - "function_registry": {"process_image": "gpu_config"}, - "routes": {"gpu_config": {"POST /api/process": "process_image"}}, - } - - def test_manifest_endpoint_registered_when_env_var_true( - self, sample_manifest, monkeypatch - ): - """Verify /manifest endpoint exists when FLASH_IS_MOTHERSHIP=true.""" - from unittest.mock 
import AsyncMock - - monkeypatch.setenv("FLASH_IS_MOTHERSHIP", "true") - - with patch("tetra_rp.runtime.lb_handler.ManifestFetcher") as MockFetcher: - mock_fetcher = AsyncMock() - mock_fetcher.get_manifest = AsyncMock(return_value=sample_manifest) - MockFetcher.return_value = mock_fetcher - - app = create_lb_handler({}, include_execute=False) - routes = [route.path for route in app.routes] - - assert "/manifest" in routes - - def test_manifest_endpoint_not_registered_when_env_var_false( - self, sample_manifest, monkeypatch - ): - """Verify /manifest endpoint doesn't exist when FLASH_IS_MOTHERSHIP=false.""" - monkeypatch.setenv("FLASH_IS_MOTHERSHIP", "false") - - app = create_lb_handler({}, include_execute=False) - routes = [route.path for route in app.routes] - - assert "/manifest" not in routes - - def test_manifest_endpoint_not_registered_when_env_var_missing( - self, sample_manifest - ): - """Verify /manifest endpoint doesn't exist when env var not set.""" - app = create_lb_handler({}, include_execute=False) - client = TestClient(app) - - response = client.get("/manifest") - assert response.status_code == 404 - - def test_manifest_endpoint_returns_200_with_valid_manifest( - self, sample_manifest, monkeypatch - ): - """Test happy path - endpoint returns 200 with valid manifest.""" - from unittest.mock import AsyncMock - - monkeypatch.setenv("FLASH_IS_MOTHERSHIP", "true") - - with patch("tetra_rp.runtime.lb_handler.ManifestFetcher") as MockFetcher: - mock_fetcher = AsyncMock() - mock_fetcher.get_manifest = AsyncMock(return_value=sample_manifest) - MockFetcher.return_value = mock_fetcher - - app = create_lb_handler({}, include_execute=False) - client = TestClient(app) - - response = client.get("/manifest") - - assert response.status_code == 200 - assert response.json() == sample_manifest - - def test_manifest_endpoint_returns_404_when_manifest_missing(self, monkeypatch): - """Test endpoint returns 404 when manifest file not found.""" - from unittest.mock import 
AsyncMock - - monkeypatch.setenv("FLASH_IS_MOTHERSHIP", "true") - - with patch("tetra_rp.runtime.lb_handler.ManifestFetcher") as MockFetcher: - mock_fetcher = AsyncMock() - mock_fetcher.get_manifest = AsyncMock(return_value={}) - MockFetcher.return_value = mock_fetcher - - app = create_lb_handler({}, include_execute=False) - client = TestClient(app) - - response = client.get("/manifest") - - assert response.status_code == 404 - data = response.json() - assert data["error"] == "Manifest not found" - assert "Could not load" in data["detail"] - - def test_manifest_endpoint_case_insensitive_env_var_true( - self, sample_manifest, monkeypatch - ): - """Test endpoint registration with different case variations of 'true'.""" - from unittest.mock import AsyncMock - - for env_value in ["True", "TRUE", "TrUe"]: - monkeypatch.setenv("FLASH_IS_MOTHERSHIP", env_value) - - with patch("tetra_rp.runtime.lb_handler.ManifestFetcher") as MockFetcher: - mock_fetcher = AsyncMock() - mock_fetcher.get_manifest = AsyncMock(return_value=sample_manifest) - MockFetcher.return_value = mock_fetcher - - app = create_lb_handler({}, include_execute=False) - routes = [route.path for route in app.routes] - - assert "/manifest" in routes - - def test_manifest_endpoint_case_insensitive_env_var_false(self, monkeypatch): - """Test endpoint not registered with non-'true' values.""" - for env_value in ["False", "false", "yes", "1", ""]: - monkeypatch.setenv("FLASH_IS_MOTHERSHIP", env_value) - - app = create_lb_handler({}, include_execute=False) - routes = [route.path for route in app.routes] - - assert "/manifest" not in routes - - def test_manifest_endpoint_response_structure(self, sample_manifest, monkeypatch): - """Test that manifest response has correct structure.""" - from unittest.mock import AsyncMock - - monkeypatch.setenv("FLASH_IS_MOTHERSHIP", "true") - - with patch("tetra_rp.runtime.lb_handler.ManifestFetcher") as MockFetcher: - mock_fetcher = AsyncMock() - mock_fetcher.get_manifest = 
AsyncMock(return_value=sample_manifest) - MockFetcher.return_value = mock_fetcher - - app = create_lb_handler({}, include_execute=False) - client = TestClient(app) - - response = client.get("/manifest") - manifest = response.json() - - # Verify manifest structure - assert "version" in manifest - assert "generated_at" in manifest - assert "project_name" in manifest - assert "resources" in manifest - assert "function_registry" in manifest - - def test_manifest_endpoint_with_empty_resources(self, monkeypatch): - """Test endpoint behavior when manifest has no resources.""" - from unittest.mock import AsyncMock - - monkeypatch.setenv("FLASH_IS_MOTHERSHIP", "true") - - empty_manifest = { - "version": "1.0", - "project_name": "test", - "resources": {}, - "function_registry": {}, - } - - with patch("tetra_rp.runtime.lb_handler.ManifestFetcher") as MockFetcher: - mock_fetcher = AsyncMock() - mock_fetcher.get_manifest = AsyncMock(return_value=empty_manifest) - MockFetcher.return_value = mock_fetcher - - app = create_lb_handler({}, include_execute=False) - client = TestClient(app) - - response = client.get("/manifest") - - # Should return 404 if no resources - assert response.status_code == 404 - - def test_manifest_endpoint_with_none_manifest(self, monkeypatch): - """Test endpoint behavior when get_manifest returns None.""" - from unittest.mock import AsyncMock - - monkeypatch.setenv("FLASH_IS_MOTHERSHIP", "true") - - with patch("tetra_rp.runtime.lb_handler.ManifestFetcher") as MockFetcher: - mock_fetcher = AsyncMock() - mock_fetcher.get_manifest = AsyncMock(return_value=None) - MockFetcher.return_value = mock_fetcher - - app = create_lb_handler({}, include_execute=False) - client = TestClient(app) - - response = client.get("/manifest") - - assert response.status_code == 404 - - def test_manifest_endpoint_coexists_with_execute( - self, sample_manifest, monkeypatch - ): - """Test that /manifest endpoint coexists with /execute endpoint.""" - from unittest.mock import AsyncMock 
- - monkeypatch.setenv("FLASH_IS_MOTHERSHIP", "true") - - with patch("tetra_rp.runtime.lb_handler.ManifestFetcher") as MockFetcher: - mock_fetcher = AsyncMock() - mock_fetcher.get_manifest = AsyncMock(return_value=sample_manifest) - MockFetcher.return_value = mock_fetcher - - app = create_lb_handler({}, include_execute=True) - routes = [route.path for route in app.routes] - - assert "/manifest" in routes - assert "/execute" in routes - - def test_manifest_endpoint_coexists_with_user_routes( - self, sample_manifest, monkeypatch - ): - """Test that /manifest endpoint coexists with user-defined routes.""" - from unittest.mock import AsyncMock - - monkeypatch.setenv("FLASH_IS_MOTHERSHIP", "true") - - async def dummy_handler(): - return {"result": "ok"} - - route_registry = {("GET", "/api/health"): dummy_handler} - - with patch("tetra_rp.runtime.lb_handler.ManifestFetcher") as MockFetcher: - mock_fetcher = AsyncMock() - mock_fetcher.get_manifest = AsyncMock(return_value=sample_manifest) - MockFetcher.return_value = mock_fetcher - - app = create_lb_handler(route_registry, include_execute=False) - routes = [route.path for route in app.routes] - - assert "/manifest" in routes - assert "/api/health" in routes - - def test_manifest_endpoint_content_type(self, sample_manifest, monkeypatch): - """Test that /manifest endpoint returns proper JSON content-type.""" - from unittest.mock import AsyncMock - - monkeypatch.setenv("FLASH_IS_MOTHERSHIP", "true") - - with patch("tetra_rp.runtime.lb_handler.ManifestFetcher") as MockFetcher: - mock_fetcher = AsyncMock() - mock_fetcher.get_manifest = AsyncMock(return_value=sample_manifest) - MockFetcher.return_value = mock_fetcher - - app = create_lb_handler({}, include_execute=False) - client = TestClient(app) - - response = client.get("/manifest") - - assert response.headers["content-type"] == "application/json" - - def test_manifest_endpoint_with_complex_manifest(self, monkeypatch): - """Test endpoint with complex multi-resource 
manifest.""" - from unittest.mock import AsyncMock - - monkeypatch.setenv("FLASH_IS_MOTHERSHIP", "true") - - complex_manifest = { - "version": "1.0", - "generated_at": "2024-01-15T10:30:00Z", - "project_name": "complex-app", - "resources": { - "gpu_config": { - "resource_type": "LoadBalancerSlsResource", - "handler_file": "handler_gpu.py", - "functions": [ - { - "name": "process_gpu", - "module": "workers.gpu", - "is_async": True, - "is_class": False, - } - ], - }, - "cpu_config": { - "resource_type": "ServerlessEndpoint", - "handler_file": "handler_cpu.py", - "functions": [ - { - "name": "process_cpu", - "module": "workers.cpu", - "is_async": True, - "is_class": False, - } - ], - }, - }, - "function_registry": { - "process_gpu": "gpu_config", - "process_cpu": "cpu_config", - }, - } - - with patch("tetra_rp.runtime.lb_handler.ManifestFetcher") as MockFetcher: - mock_fetcher = AsyncMock() - mock_fetcher.get_manifest = AsyncMock(return_value=complex_manifest) - MockFetcher.return_value = mock_fetcher - - app = create_lb_handler({}, include_execute=False) - client = TestClient(app) - - response = client.get("/manifest") - - assert response.status_code == 200 - manifest = response.json() - assert len(manifest["resources"]) == 2 - assert "gpu_config" in manifest["resources"] - assert "cpu_config" in manifest["resources"] - - def test_manifest_endpoint_uses_fetcher_with_caching( - self, sample_manifest, monkeypatch - ): - """Verify GET /manifest uses ManifestFetcher with caching.""" - from unittest.mock import AsyncMock - - monkeypatch.setenv("FLASH_IS_MOTHERSHIP", "true") - - with patch("tetra_rp.runtime.lb_handler.ManifestFetcher") as MockFetcher: - mock_fetcher = AsyncMock() - mock_fetcher.get_manifest = AsyncMock(return_value=sample_manifest) - MockFetcher.return_value = mock_fetcher - - app = create_lb_handler({}, include_execute=False) - client = TestClient(app) - - # First request - response1 = client.get("/manifest") - assert response1.status_code == 200 - assert 
response1.json() == sample_manifest - - # Second request - should reuse fetcher - response2 = client.get("/manifest") - assert response2.status_code == 200 - - # Verify fetcher was called (once per request) - assert mock_fetcher.get_manifest.call_count == 2 +from tetra_rp.runtime.lb_handler import create_lb_handler class TestExecuteEndpointStillWorks: diff --git a/tests/unit/runtime/test_manifest_client.py b/tests/unit/runtime/test_manifest_client.py deleted file mode 100644 index 76618363..00000000 --- a/tests/unit/runtime/test_manifest_client.py +++ /dev/null @@ -1,207 +0,0 @@ -"""Tests for ManifestClient.""" - -import asyncio -import os -from unittest.mock import AsyncMock, MagicMock, patch - -import pytest - -from tetra_rp.runtime.manifest_client import ( - ManifestClient, - ManifestServiceUnavailableError, -) - - -class TestManifestClient: - """Test ManifestClient functionality.""" - - @pytest.fixture - def mock_response(self): - """Mock successful HTTP response.""" - response = MagicMock() - response.status_code = 200 - response.json.return_value = { - "version": "1.0", - "generated_at": "2025-01-03T12:00:00Z", - "project_name": "test-app", - "resources": { - "gpu_config": {"endpoint_url": "https://api.runpod.io/v2/gpu123"}, - "cpu_config": {"endpoint_url": "https://api.runpod.io/v2/cpu456"}, - }, - "function_registry": {}, - } - return response - - def test_init_with_url(self): - """Test initialization with explicit URL.""" - client = ManifestClient(mothership_url="https://mothership.example.com") - assert client.mothership_url == "https://mothership.example.com" - - def test_init_from_env(self): - """Test initialization from environment variable.""" - with patch.dict(os.environ, {"FLASH_MOTHERSHIP_ID": "mothership123"}): - client = ManifestClient() - assert client.mothership_url == "https://mothership123.api.runpod.ai" - - def test_init_missing_url(self): - """Test initialization fails without URL.""" - with patch.dict(os.environ, {}, clear=True): - with 
pytest.raises(ValueError, match="mothership_url required"): - ManifestClient() - - def test_init_explicit_over_env(self): - """Test explicit URL takes precedence over env var.""" - with patch.dict(os.environ, {"FLASH_MOTHERSHIP_ID": "env-mothership"}): - client = ManifestClient(mothership_url="https://explicit.com") - assert client.mothership_url == "https://explicit.com" - - @pytest.mark.asyncio - async def test_get_manifest_success(self, mock_response): - """Test successful manifest fetch.""" - client = ManifestClient(mothership_url="https://mothership.example.com") - - with patch("tetra_rp.runtime.manifest_client.httpx"): - mock_client = AsyncMock() - mock_client.__aenter__.return_value = mock_client - mock_client.get.return_value = mock_response - - with patch.object(client, "_get_client", return_value=mock_client): - manifest = await client.get_manifest() - - assert manifest == { - "version": "1.0", - "generated_at": "2025-01-03T12:00:00Z", - "project_name": "test-app", - "resources": { - "gpu_config": { - "endpoint_url": "https://api.runpod.io/v2/gpu123" - }, - "cpu_config": { - "endpoint_url": "https://api.runpod.io/v2/cpu456" - }, - }, - "function_registry": {}, - } - - @pytest.mark.asyncio - async def test_get_manifest_http_error(self): - """Test handling of HTTP errors.""" - client = ManifestClient(mothership_url="https://mothership.example.com") - - response = MagicMock() - response.status_code = 500 - response.text = "Internal server error" - - with patch.object(client, "_get_client") as mock_get_client: - mock_http_client = AsyncMock() - mock_http_client.get.return_value = response - mock_http_client.is_closed = False - mock_get_client.return_value = mock_http_client - - with pytest.raises(ManifestServiceUnavailableError, match="500"): - await client.get_manifest() - - @pytest.mark.asyncio - async def test_get_manifest_timeout(self): - """Test handling of request timeout.""" - client = ManifestClient( - mothership_url="https://mothership.example.com", 
timeout=0.1 - ) - - with patch.object(client, "_get_client") as mock_get_client: - mock_http_client = AsyncMock() - mock_http_client.get.side_effect = asyncio.TimeoutError("Timed out") - mock_http_client.is_closed = False - mock_get_client.return_value = mock_http_client - - with pytest.raises( - ManifestServiceUnavailableError, match="after \\d+ attempts" - ): - await client.get_manifest() - - @pytest.mark.asyncio - async def test_get_manifest_retry(self): - """Test retry logic on transient failure.""" - client = ManifestClient( - mothership_url="https://mothership.example.com", max_retries=3 - ) - - response = MagicMock() - response.status_code = 200 - response.json.return_value = { - "version": "1.0", - "resources": {"gpu": {"endpoint_url": "https://gpu.example.com"}}, - "function_registry": {}, - } - - with patch.object(client, "_get_client") as mock_get_client: - mock_http_client = AsyncMock() - - # First two calls fail, third succeeds - mock_http_client.get.side_effect = [ - Exception("Network error"), - Exception("Network error"), - response, - ] - mock_http_client.is_closed = False - mock_get_client.return_value = mock_http_client - - with patch( - "tetra_rp.runtime.manifest_client.asyncio.sleep", - new_callable=AsyncMock, - ): - manifest = await client.get_manifest() - assert manifest == { - "version": "1.0", - "resources": {"gpu": {"endpoint_url": "https://gpu.example.com"}}, - "function_registry": {}, - } - assert mock_http_client.get.call_count == 3 - - @pytest.mark.asyncio - async def test_get_manifest_exhaust_retries(self): - """Test failure after exhausting retries.""" - client = ManifestClient( - mothership_url="https://mothership.example.com", max_retries=2 - ) - - with patch.object(client, "_get_client") as mock_get_client: - mock_http_client = AsyncMock() - mock_http_client.get.side_effect = Exception("Always fails") - mock_http_client.is_closed = False - mock_get_client.return_value = mock_http_client - - with patch( - 
"tetra_rp.runtime.manifest_client.asyncio.sleep", - new_callable=AsyncMock, - ): - with pytest.raises( - ManifestServiceUnavailableError, match="after 2 attempts" - ): - await client.get_manifest() - - @pytest.mark.asyncio - async def test_context_manager(self): - """Test async context manager.""" - client = ManifestClient(mothership_url="https://mothership.example.com") - - with patch.object(client, "close", new_callable=AsyncMock) as mock_close: - async with client: - pass - - mock_close.assert_called_once() - - @pytest.mark.asyncio - async def test_close(self): - """Test proper cleanup of HTTP client.""" - client = ManifestClient(mothership_url="https://mothership.example.com") - - with patch("tetra_rp.runtime.manifest_client.httpx"): - mock_http_client = AsyncMock() - mock_http_client.is_closed = False - - with patch.object(client, "_get_client", return_value=mock_http_client): - client._client = mock_http_client - await client.close() - - mock_http_client.aclose.assert_called_once() From 23a896b7e469387c9bfb5de2c709257474f7a6dc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dean=20Qui=C3=B1anola?= Date: Thu, 22 Jan 2026 01:32:24 -0800 Subject: [PATCH 04/11] fix: resolve async test failures in cross-endpoint routing tests - Fixed test_service_registry.py: Added RUNPOD_ENDPOINT_ID env var to tests that call _ensure_manifest_loaded() (test_is_local_function_remote, test_get_resource_for_function_remote) - Fixed test_production_wrapper.py: Updated async method mocks to use AsyncMock properly for all get_resource_for_function calls - Fixed test_cross_endpoint_routing.py: Added RUNPOD_ENDPOINT_ID env var to test_manifest_loading_on_demand - Fixed conftest.py: Made worker_temp_dir fixture compatible with both parallel (xdist) and serial execution by using request fixture to detect worker_id instead of requiring it All 707 tests now pass with 65% coverage. Tests use peer-to-peer routing via StateManagerClient as per recent hub-and-spoke cleanup. 
--- src/tetra_rp/runtime/production_wrapper.py | 6 +- tests/conftest.py | 13 +- .../test_cross_endpoint_routing.py | 1 + tests/unit/runtime/test_production_wrapper.py | 16 +- tests/unit/runtime/test_service_registry.py | 199 ++++++++++-------- 5 files changed, 141 insertions(+), 94 deletions(-) diff --git a/src/tetra_rp/runtime/production_wrapper.py b/src/tetra_rp/runtime/production_wrapper.py index 22a48f9e..eb600ced 100644 --- a/src/tetra_rp/runtime/production_wrapper.py +++ b/src/tetra_rp/runtime/production_wrapper.py @@ -61,7 +61,9 @@ async def wrap_function_execution( # Determine routing try: - resource = self.service_registry.get_resource_for_function(function_name) + resource = await self.service_registry.get_resource_for_function( + function_name + ) except ValueError as e: # Function not in manifest, execute locally logger.debug( @@ -126,7 +128,7 @@ async def wrap_class_method_execution( # Determine routing try: - resource = self.service_registry.get_resource_for_function(class_name) + resource = await self.service_registry.get_resource_for_function(class_name) except ValueError: # Class not in manifest, execute locally logger.debug(f"Class {class_name} not in manifest, executing locally") diff --git a/tests/conftest.py b/tests/conftest.py index 478855aa..8ced0595 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -210,7 +210,9 @@ def sample_pod_template() -> Dict[str, Any]: @pytest.fixture(scope="session") -def worker_temp_dir(tmp_path_factory: pytest.TempPathFactory, worker_id: str) -> Path: +def worker_temp_dir( + tmp_path_factory: pytest.TempPathFactory, request: pytest.FixtureRequest +) -> Path: """Provide worker-specific temporary directory for file system isolation. Each xdist worker gets its own isolated temp directory to prevent @@ -218,11 +220,18 @@ def worker_temp_dir(tmp_path_factory: pytest.TempPathFactory, worker_id: str) -> Args: tmp_path_factory: Pytest's temporary path factory. 
- worker_id: Worker ID from pytest-xdist ('master' for single worker). + request: Pytest request object. Returns: Path to worker-specific temporary directory. """ + # Try to get worker_id from xdist, default to 'master' if not available + worker_id = getattr(request.config, "workerinput", None) + if worker_id is None: + worker_id = "master" + else: + worker_id = worker_id["workerid"] + if worker_id == "master": # Single worker (non-parallel) return tmp_path_factory.mktemp("test_data") diff --git a/tests/integration/test_cross_endpoint_routing.py b/tests/integration/test_cross_endpoint_routing.py index 3ce052c4..9e8a8349 100644 --- a/tests/integration/test_cross_endpoint_routing.py +++ b/tests/integration/test_cross_endpoint_routing.py @@ -191,6 +191,7 @@ async def test_manifest_loading_on_demand(self, manifest): "os.environ", { "FLASH_RESOURCE_NAME": "gpu_config", + "RUNPOD_ENDPOINT_ID": "mothership-id", }, ): endpoint_registry = { diff --git a/tests/unit/runtime/test_production_wrapper.py b/tests/unit/runtime/test_production_wrapper.py index bda5c31d..5b7945fa 100644 --- a/tests/unit/runtime/test_production_wrapper.py +++ b/tests/unit/runtime/test_production_wrapper.py @@ -46,7 +46,7 @@ async def test_wrap_function_local_execution( self, wrapper, mock_registry, original_stub, sample_function ): """Test routing local function to original stub.""" - mock_registry.get_resource_for_function.return_value = None + mock_registry.get_resource_for_function = AsyncMock(return_value=None) await wrapper.wrap_function_execution( original_stub, @@ -74,7 +74,7 @@ async def test_wrap_function_remote_execution( mock_resource.run_sync = AsyncMock() mock_resource.run_sync.return_value = MagicMock(error="", output=42) - mock_registry.get_resource_for_function.return_value = mock_resource + mock_registry.get_resource_for_function = AsyncMock(return_value=mock_resource) result = await wrapper.wrap_function_execution( original_stub, @@ -97,8 +97,8 @@ async def 
test_wrap_function_not_in_manifest( self, wrapper, mock_registry, original_stub, sample_function ): """Test function not found in manifest executes locally.""" - mock_registry.get_resource_for_function.side_effect = ValueError( - "Function not found" + mock_registry.get_resource_for_function = AsyncMock( + side_effect=ValueError("Function not found") ) await wrapper.wrap_function_execution( @@ -123,7 +123,7 @@ async def test_wrap_function_remote_error( mock_resource.run_sync = AsyncMock() mock_resource.run_sync.return_value = MagicMock(error="Remote execution failed") - mock_registry.get_resource_for_function.return_value = mock_resource + mock_registry.get_resource_for_function = AsyncMock(return_value=mock_resource) with pytest.raises(Exception, match="Remote execution failed"): await wrapper.wrap_function_execution( @@ -137,7 +137,7 @@ async def test_wrap_function_remote_error( @pytest.mark.asyncio async def test_wrap_function_loads_manifest(self, wrapper, mock_registry): """Test that manifest is loaded before routing decision.""" - mock_registry.get_resource_for_function.return_value = None + mock_registry.get_resource_for_function = AsyncMock(return_value=None) async def sample_func(): pass @@ -156,7 +156,7 @@ async def test_wrap_class_method_local(self, wrapper, mock_registry, original_st request = MagicMock() request.class_name = "MyClass" - mock_registry.get_resource_for_function.return_value = None + mock_registry.get_resource_for_function = AsyncMock(return_value=None) await wrapper.wrap_class_method_execution(original_stub, request) @@ -184,7 +184,7 @@ async def test_wrap_class_method_remote( mock_resource.run_sync = AsyncMock() mock_resource.run_sync.return_value = MagicMock(error="", output="done") - mock_registry.get_resource_for_function.return_value = mock_resource + mock_registry.get_resource_for_function = AsyncMock(return_value=mock_resource) result = await wrapper.wrap_class_method_execution(original_stub, request) diff --git 
a/tests/unit/runtime/test_service_registry.py b/tests/unit/runtime/test_service_registry.py index 7bf208ea..124e5eab 100644 --- a/tests/unit/runtime/test_service_registry.py +++ b/tests/unit/runtime/test_service_registry.py @@ -103,81 +103,100 @@ def test_get_current_endpoint_id_not_set(self): registry = ServiceRegistry(manifest_path=Path("/nonexistent")) assert registry.get_current_endpoint_id() is None - def test_is_local_function_local(self, manifest_file): + @pytest.mark.asyncio + async def test_is_local_function_local(self, manifest_file): """Test determining local function using FLASH_RESOURCE_NAME.""" with patch.dict(os.environ, {"FLASH_RESOURCE_NAME": "gpu_config"}): registry = ServiceRegistry(manifest_path=manifest_file) - assert registry.is_local_function("gpu_task") is True - assert registry.is_local_function("inference") is True + assert await registry.is_local_function("gpu_task") is True + assert await registry.is_local_function("inference") is True - def test_is_local_function_remote(self, manifest_file): + @pytest.mark.asyncio + async def test_is_local_function_remote(self, manifest_file): """Test determining remote function (with manifest loaded).""" - with patch.dict(os.environ, {"FLASH_RESOURCE_NAME": "gpu_config"}): + with patch.dict( + os.environ, + { + "FLASH_RESOURCE_NAME": "gpu_config", + "RUNPOD_ENDPOINT_ID": "mothership-id", + }, + ): + registry = ServiceRegistry(manifest_path=manifest_file) + + # Mock the manifest client mock_client = AsyncMock() - mock_client.get_manifest.return_value = { - "cpu_config": "https://cpu.example.com" + mock_client.get_persisted_manifest.return_value = { + "resources_endpoints": {"cpu_config": "https://cpu.example.com"} } + registry._manifest_client = mock_client - registry = ServiceRegistry( - manifest_path=manifest_file, manifest_client=mock_client - ) # After manifest is loaded, CPU tasks should be recognized as remote - # (but is_local_function doesn't async load, so returns True for now) - # This is 
actually expected behavior - sync method can't load async manifest - assert registry.is_local_function("preprocess") is True + await registry._ensure_manifest_loaded() + assert await registry.is_local_function("preprocess") is False - def test_is_local_function_not_in_manifest(self, manifest_file): + @pytest.mark.asyncio + async def test_is_local_function_not_in_manifest(self, manifest_file): """Test function not in manifest.""" registry = ServiceRegistry(manifest_path=manifest_file) # Unknown function assumed local - assert registry.is_local_function("unknown_function") is True + assert await registry.is_local_function("unknown_function") is True - def test_get_endpoint_for_function_local(self, manifest_file): + @pytest.mark.asyncio + async def test_get_endpoint_for_function_local(self, manifest_file): """Test getting endpoint for local function using FLASH_RESOURCE_NAME.""" with patch.dict(os.environ, {"FLASH_RESOURCE_NAME": "gpu_config"}): registry = ServiceRegistry(manifest_path=manifest_file) - endpoint = registry.get_endpoint_for_function("gpu_task") + endpoint = await registry.get_endpoint_for_function("gpu_task") assert endpoint is None # Local returns None - def test_get_endpoint_for_function_remote_no_manifest(self, manifest_file): + @pytest.mark.asyncio + async def test_get_endpoint_for_function_remote_no_manifest(self, manifest_file): """Test getting endpoint for remote function without manifest.""" with patch.dict(os.environ, {"FLASH_RESOURCE_NAME": "gpu_config"}): registry = ServiceRegistry(manifest_path=manifest_file) # CPU function is remote, but no manifest loaded - endpoint = registry.get_endpoint_for_function("preprocess") + endpoint = await registry.get_endpoint_for_function("preprocess") assert endpoint is None - def test_get_endpoint_for_function_not_in_manifest(self, manifest_file): + @pytest.mark.asyncio + async def test_get_endpoint_for_function_not_in_manifest(self, manifest_file): """Test getting endpoint for unknown function.""" registry 
= ServiceRegistry(manifest_path=manifest_file) with pytest.raises(ValueError, match="not found in manifest"): - registry.get_endpoint_for_function("unknown_function") + await registry.get_endpoint_for_function("unknown_function") - def test_get_resource_for_function_local(self, manifest_file): + @pytest.mark.asyncio + async def test_get_resource_for_function_local(self, manifest_file): """Test getting ServerlessResource for local function using FLASH_RESOURCE_NAME.""" with patch.dict(os.environ, {"FLASH_RESOURCE_NAME": "gpu_config"}): registry = ServiceRegistry(manifest_path=manifest_file) - resource = registry.get_resource_for_function("gpu_task") + resource = await registry.get_resource_for_function("gpu_task") # Local function returns None assert resource is None - def test_get_resource_for_function_remote(self, manifest_file): + @pytest.mark.asyncio + async def test_get_resource_for_function_remote(self, manifest_file): """Test getting ServerlessResource for remote function.""" - with patch.dict(os.environ, {"FLASH_RESOURCE_NAME": "gpu_config"}): + with patch.dict( + os.environ, + { + "FLASH_RESOURCE_NAME": "gpu_config", + "RUNPOD_ENDPOINT_ID": "mothership-id", + }, + ): + registry = ServiceRegistry(manifest_path=manifest_file) + + # Mock the manifest client mock_client = AsyncMock() - mock_client.get_manifest.return_value = { - "cpu_config": "https://api.runpod.io/v2/abc123" + mock_client.get_persisted_manifest.return_value = { + "resources_endpoints": {"cpu_config": "https://api.runpod.io/v2/abc123"} } + registry._manifest_client = mock_client - registry = ServiceRegistry( - manifest_path=manifest_file, manifest_client=mock_client - ) - # Manually set endpoint registry to simulate loaded state - registry._endpoint_registry = { - "cpu_config": "https://api.runpod.io/v2/abc123" - } + # Load manifest + await registry._ensure_manifest_loaded() - resource = registry.get_resource_for_function("preprocess") + resource = await 
registry.get_resource_for_function("preprocess") # Should return ServerlessResource assert resource is not None @@ -185,11 +204,12 @@ def test_get_resource_for_function_remote(self, manifest_file): # Name starts with remote_preprocess (may have random suffix appended) assert resource.name.startswith("remote_preprocess") - def test_get_resource_for_function_not_in_manifest(self, manifest_file): + @pytest.mark.asyncio + async def test_get_resource_for_function_not_in_manifest(self, manifest_file): """Test getting resource for unknown function.""" registry = ServiceRegistry(manifest_path=manifest_file) with pytest.raises(ValueError, match="not found in manifest"): - registry.get_resource_for_function("unknown_function") + await registry.get_resource_for_function("unknown_function") @pytest.mark.asyncio async def test_ensure_manifest_loaded(self, manifest_file): @@ -199,70 +219,79 @@ async def test_ensure_manifest_loaded(self, manifest_file): "cpu_config": "https://cpu.example.com", } - mock_client = AsyncMock() - mock_client.get_manifest.return_value = mock_endpoint_registry + with patch.dict(os.environ, {"RUNPOD_ENDPOINT_ID": "mothership-id"}): + registry = ServiceRegistry(manifest_path=manifest_file, cache_ttl=10) - registry = ServiceRegistry( - manifest_path=manifest_file, manifest_client=mock_client, cache_ttl=10 - ) + # Mock the manifest client + mock_client = AsyncMock() + mock_client.get_persisted_manifest.return_value = { + "resources_endpoints": mock_endpoint_registry + } + registry._manifest_client = mock_client - # Endpoint registry not loaded yet - assert registry._endpoint_registry == {} + # Endpoint registry not loaded yet + assert registry._endpoint_registry == {} - # Load manifest - await registry._ensure_manifest_loaded() + # Load manifest + await registry._ensure_manifest_loaded() - # Should now have loaded endpoint registry - assert registry._endpoint_registry == mock_endpoint_registry - mock_client.get_manifest.assert_called_once() + # Should now 
have loaded endpoint registry + assert registry._endpoint_registry == mock_endpoint_registry + mock_client.get_persisted_manifest.assert_called_once() @pytest.mark.asyncio async def test_ensure_manifest_cache_respects_ttl(self, manifest_file): """Test that manifest cache respects TTL.""" mock_endpoint_registry = {"gpu_config": "https://gpu.example.com"} - mock_client = AsyncMock() - mock_client.get_manifest.return_value = mock_endpoint_registry + with patch.dict(os.environ, {"RUNPOD_ENDPOINT_ID": "mothership-id"}): + registry = ServiceRegistry(manifest_path=manifest_file, cache_ttl=1) - registry = ServiceRegistry( - manifest_path=manifest_file, manifest_client=mock_client, cache_ttl=1 - ) + # Mock the manifest client + mock_client = AsyncMock() + mock_client.get_persisted_manifest.return_value = { + "resources_endpoints": mock_endpoint_registry + } + registry._manifest_client = mock_client - # Load manifest - await registry._ensure_manifest_loaded() - assert mock_client.get_manifest.call_count == 1 + # Load manifest + await registry._ensure_manifest_loaded() + assert mock_client.get_persisted_manifest.call_count == 1 - # Immediate reload should use cache - await registry._ensure_manifest_loaded() - assert mock_client.get_manifest.call_count == 1 + # Immediate reload should use cache + await registry._ensure_manifest_loaded() + assert mock_client.get_persisted_manifest.call_count == 1 - # After TTL, should reload - registry._endpoint_registry_loaded_at = time.time() - 2 # 2 seconds ago - await registry._ensure_manifest_loaded() - assert mock_client.get_manifest.call_count == 2 + # After TTL, should reload + registry._endpoint_registry_loaded_at = time.time() - 2 # 2 seconds ago + await registry._ensure_manifest_loaded() + assert mock_client.get_persisted_manifest.call_count == 2 @pytest.mark.asyncio async def test_refresh_manifest(self, manifest_file): """Test forcing manifest refresh.""" mock_endpoint_registry = {"gpu_config": "https://gpu.example.com"} - 
mock_client = AsyncMock() - mock_client.get_manifest.return_value = mock_endpoint_registry + with patch.dict(os.environ, {"RUNPOD_ENDPOINT_ID": "mothership-id"}): + registry = ServiceRegistry(manifest_path=manifest_file, cache_ttl=3600) - registry = ServiceRegistry( - manifest_path=manifest_file, manifest_client=mock_client, cache_ttl=3600 - ) + # Mock the manifest client + mock_client = AsyncMock() + mock_client.get_persisted_manifest.return_value = { + "resources_endpoints": mock_endpoint_registry + } + registry._manifest_client = mock_client - # Load manifest - await registry._ensure_manifest_loaded() - assert mock_client.get_manifest.call_count == 1 + # Load manifest + await registry._ensure_manifest_loaded() + assert mock_client.get_persisted_manifest.call_count == 1 - # Force refresh - registry.refresh_manifest() + # Force refresh + registry.refresh_manifest() - # Next load should fetch again - await registry._ensure_manifest_loaded() - assert mock_client.get_manifest.call_count == 2 + # Next load should fetch again + await registry._ensure_manifest_loaded() + assert mock_client.get_persisted_manifest.call_count == 2 def test_get_manifest(self, manifest_file): """Test getting manifest.""" @@ -292,16 +321,22 @@ def test_get_resource_functions_not_found(self, manifest_file): functions = registry.get_resource_functions("nonexistent") assert functions == [] - def test_init_no_manifest_client_no_mothership_id(self, manifest_file): - """Test initialization without manifest client or FLASH_MOTHERSHIP_ID.""" + def test_init_no_manifest_client_no_runpod_key(self, manifest_file): + """Test initialization without RUNPOD_API_KEY.""" with patch.dict(os.environ, {}, clear=True): - registry = ServiceRegistry(manifest_path=manifest_file) - assert registry._manifest_client is None + with patch( + "tetra_rp.runtime.service_registry.StateManagerClient" + ) as mock_client_class: + mock_client_class.side_effect = Exception("No API key") + registry = 
ServiceRegistry(manifest_path=manifest_file) + # Should handle the exception and set client to None + assert registry._manifest_client is None @pytest.mark.asyncio async def test_ensure_manifest_loaded_unavailable_client(self, manifest_file): """Test manifest loading when client is None.""" - registry = ServiceRegistry(manifest_path=manifest_file, manifest_client=None) + registry = ServiceRegistry(manifest_path=manifest_file) + registry._manifest_client = None # Should not fail, just log warning await registry._ensure_manifest_loaded() assert registry._endpoint_registry == {} From 2affb7af6606ac8f4f6813c5655a6b4b5486278b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dean=20Qui=C3=B1anola?= Date: Thu, 22 Jan 2026 08:42:49 -0800 Subject: [PATCH 05/11] feat: implement cross-endpoint routing reliability infrastructure Cleanup: - Remove /manifest endpoint from LoadBalancer handler generator (lb_handler_generator.py) - Complete peer-to-peer architecture migration - all endpoints now query State Manager directly Phase 2 Foundation (Reliability Features): - Add reliability_config.py: Centralized configuration for circuit breaker, load balancing, retry logic, and metrics collection with environment variable support - Add circuit_breaker.py: Circuit breaker pattern implementation with state machine (CLOSED, OPEN, HALF_OPEN) using sliding window failure detection - Add load_balancer.py: Load balancing strategies (round-robin, least-connections, random) for distributing requests across multiple endpoints - Add retry_manager.py: Retry logic with exponential backoff, jitter, and circuit breaker integration for handling transient failures - Add metrics.py: Structured logging-based metrics collection with helpers for circuit breaker, retry, and load balancer telemetry Next: Phase 2 will implement upfront provisioning in deployment flow. 
--- .../build_utils/lb_handler_generator.py | 59 +--- src/tetra_rp/runtime/circuit_breaker.py | 272 +++++++++++++++ src/tetra_rp/runtime/load_balancer.py | 160 +++++++++ src/tetra_rp/runtime/metrics.py | 325 ++++++++++++++++++ src/tetra_rp/runtime/reliability_config.py | 148 ++++++++ src/tetra_rp/runtime/retry_manager.py | 118 +++++++ tests/unit/runtime/test_circuit_breaker.py | 188 ++++++++++ tests/unit/runtime/test_load_balancer.py | 145 ++++++++ tests/unit/runtime/test_reliability_config.py | 181 ++++++++++ tests/unit/runtime/test_retry_manager.py | 165 +++++++++ 10 files changed, 1703 insertions(+), 58 deletions(-) create mode 100644 src/tetra_rp/runtime/circuit_breaker.py create mode 100644 src/tetra_rp/runtime/load_balancer.py create mode 100644 src/tetra_rp/runtime/metrics.py create mode 100644 src/tetra_rp/runtime/reliability_config.py create mode 100644 src/tetra_rp/runtime/retry_manager.py create mode 100644 tests/unit/runtime/test_circuit_breaker.py create mode 100644 tests/unit/runtime/test_load_balancer.py create mode 100644 tests/unit/runtime/test_reliability_config.py create mode 100644 tests/unit/runtime/test_retry_manager.py diff --git a/src/tetra_rp/cli/commands/build_utils/lb_handler_generator.py b/src/tetra_rp/cli/commands/build_utils/lb_handler_generator.py index b639773e..3f820475 100644 --- a/src/tetra_rp/cli/commands/build_utils/lb_handler_generator.py +++ b/src/tetra_rp/cli/commands/build_utils/lb_handler_generator.py @@ -40,9 +40,6 @@ {registry} }} -# Module-level state for /manifest endpoint -_state_client: Optional[StateManagerClient] = None - # Lifespan context manager for startup/shutdown @asynccontextmanager @@ -69,10 +66,8 @@ async def lifespan(app: FastAPI): mothership_url = get_mothership_url() logger.info(f"Mothership URL: {{mothership_url}}") - # Initialize State Manager client and store in module-level state + # Initialize State Manager client for provisioning state_client = StateManagerClient() - global _state_client - 
_state_client = state_client # Spawn background provisioning task (non-blocking) manifest_path = Path(__file__).parent / "flash_manifest.json" @@ -117,58 +112,6 @@ def ping(): return {{"status": "healthy"}} -# Manifest endpoint for service discovery -@app.get("/manifest") -async def manifest(): - """Return complete authoritative manifest for service discovery. - - Fetches the full manifest from State Manager, allowing child endpoints - to synchronize their configuration. - - Returns: - dict: Complete manifest with version, generated_at, project_name, - function_registry, resources, and routes - """ - try: - import os - from tetra_rp.runtime.mothership_provisioner import is_mothership - - # Only mothership serves manifest - if not is_mothership(): - return {{"error": "Only mothership serves manifest"}}, 403 - - # Check state client initialized - global _state_client - if _state_client is None: - return {{"error": "State Manager not initialized"}}, 500 - - # Get mothership ID - mothership_id = os.getenv("RUNPOD_ENDPOINT_ID") - if not mothership_id: - return {{"error": "RUNPOD_ENDPOINT_ID not set"}}, 500 - - # Fetch persisted manifest from State Manager (single source of truth) - persisted_manifest = await _state_client.get_persisted_manifest(mothership_id) - - # First boot: no manifest yet, return minimal structure - if persisted_manifest is None: - return {{ - "version": "1.0", - "generated_at": "", - "project_name": "", - "function_registry": {{}}, - "resources": {{}}, - "routes": {{}} - }} - - # Return complete manifest - return persisted_manifest - - except Exception as e: - logger.error(f"Failed to get manifest: {{e}}") - return {{"error": str(e)}}, 500 - - if __name__ == "__main__": import uvicorn # Local development server for testing diff --git a/src/tetra_rp/runtime/circuit_breaker.py b/src/tetra_rp/runtime/circuit_breaker.py new file mode 100644 index 00000000..8065fa43 --- /dev/null +++ b/src/tetra_rp/runtime/circuit_breaker.py @@ -0,0 +1,272 @@ 
+"""Circuit breaker pattern for handling endpoint failures.""" + +import asyncio +import logging +import time +from dataclasses import dataclass, field +from datetime import datetime +from enum import Enum +from typing import Any, Callable, Optional + +logger = logging.getLogger(__name__) + + +class CircuitState(Enum): + """Circuit breaker state machine.""" + + CLOSED = "closed" + OPEN = "open" + HALF_OPEN = "half_open" + + +@dataclass +class CircuitBreakerStats: + """Statistics for a circuit breaker instance.""" + + state: CircuitState = CircuitState.CLOSED + failure_count: int = 0 + success_count: int = 0 + last_failure_at: Optional[datetime] = None + last_success_at: Optional[datetime] = None + state_changed_at: datetime = field(default_factory=datetime.utcnow) + total_requests: int = 0 + total_failures: int = 0 + + +class EndpointCircuitBreaker: + """Circuit breaker for a single endpoint with sliding window.""" + + def __init__( + self, + endpoint_url: str, + failure_threshold: int = 5, + success_threshold: int = 2, + timeout_seconds: int = 60, + window_size: int = 10, + ): + """Initialize circuit breaker for an endpoint. + + Args: + endpoint_url: URL of the endpoint to protect + failure_threshold: Failures required to open circuit + success_threshold: Successes required to close circuit + timeout_seconds: Time before attempting recovery + window_size: Size of sliding window for counting failures + """ + self.endpoint_url = endpoint_url + self.failure_threshold = failure_threshold + self.success_threshold = success_threshold + self.timeout_seconds = timeout_seconds + self.window_size = window_size + self.stats = CircuitBreakerStats() + self._lock = asyncio.Lock() + self._failure_times: list[float] = [] + + async def execute(self, func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any: + """Execute function with circuit breaker protection. 
+ + Args: + func: Async function to execute + *args: Positional arguments for func + **kwargs: Keyword arguments for func + + Returns: + Result from func + + Raises: + CircuitBreakerOpenError: If circuit is open + Exception: Any exception raised by func + """ + async with self._lock: + state = self.stats.state + + if state == CircuitState.OPEN: + # Check if timeout has passed + if self._should_attempt_recovery(): + self._transition_to_half_open() + else: + raise CircuitBreakerOpenError( + f"Circuit OPEN for {self.endpoint_url}. " + f"Retry in {self._seconds_until_recovery()}s" + ) + + # Execute function + try: + result = await func(*args, **kwargs) + await self._on_success() + return result + except Exception as e: + await self._on_failure(e) + raise + + async def _on_success(self) -> None: + """Record successful request.""" + async with self._lock: + self.stats.success_count += 1 + self.stats.total_requests += 1 + self.stats.last_success_at = datetime.utcnow() + + logger.debug( + f"Circuit breaker {self.endpoint_url}: " + f"success {self.stats.success_count}/{self.success_threshold}" + ) + + if self.stats.state == CircuitState.HALF_OPEN: + # Close circuit after successes + if self.stats.success_count >= self.success_threshold: + self._transition_to_closed() + elif self.stats.state == CircuitState.CLOSED: + # Reset failure count on success + self.stats.failure_count = 0 + self._failure_times.clear() + + async def _on_failure(self, error: Exception) -> None: + """Record failed request.""" + async with self._lock: + self.stats.failure_count += 1 + self.stats.total_failures += 1 + self.stats.total_requests += 1 + self.stats.last_failure_at = datetime.utcnow() + + # Track failure times for sliding window + now = time.time() + self._failure_times.append(now) + + # Keep only failures within window + cutoff = now - self.timeout_seconds + self._failure_times = [t for t in self._failure_times if t > cutoff] + + logger.debug( + f"Circuit breaker {self.endpoint_url}: " + 
f"failure {self.stats.failure_count}/{self.failure_threshold}, " + f"error: {error}" + ) + + if self.stats.state == CircuitState.HALF_OPEN: + # Open circuit on first failure in half-open + self._transition_to_open() + elif self.stats.state == CircuitState.CLOSED: + # Open circuit if threshold reached + if len(self._failure_times) >= self.failure_threshold: + self._transition_to_open() + + def _transition_to_open(self) -> None: + """Transition circuit to OPEN state.""" + if self.stats.state == CircuitState.OPEN: + return # Already open + self.stats.state = CircuitState.OPEN + self.stats.state_changed_at = datetime.utcnow() + self.stats.success_count = 0 + logger.warning( + f"Circuit breaker OPEN for {self.endpoint_url} " + f"after {self.stats.failure_count} failures" + ) + + def _transition_to_half_open(self) -> None: + """Transition circuit to HALF_OPEN state.""" + self.stats.state = CircuitState.HALF_OPEN + self.stats.state_changed_at = datetime.utcnow() + self.stats.failure_count = 0 + self.stats.success_count = 0 + logger.info( + f"Circuit breaker HALF_OPEN for {self.endpoint_url}, testing recovery" + ) + + def _transition_to_closed(self) -> None: + """Transition circuit to CLOSED state.""" + self.stats.state = CircuitState.CLOSED + self.stats.state_changed_at = datetime.utcnow() + self.stats.failure_count = 0 + self.stats.success_count = 0 + self._failure_times.clear() + logger.info(f"Circuit breaker CLOSED for {self.endpoint_url}, recovered") + + def _should_attempt_recovery(self) -> bool: + """Check if enough time has passed to attempt recovery.""" + if not self.stats.last_failure_at: + return False + elapsed = datetime.utcnow() - self.stats.state_changed_at + return elapsed.total_seconds() >= self.timeout_seconds + + def _seconds_until_recovery(self) -> int: + """Get seconds until recovery can be attempted.""" + if not self.stats.state_changed_at: + return self.timeout_seconds + elapsed = datetime.utcnow() - self.stats.state_changed_at + remaining = 
class CircuitBreakerRegistry:
    """Create, cache, and query one circuit breaker per endpoint URL.

    All breakers created by a registry share the thresholds supplied at
    construction time.  Creation is synchronous and never awaits between
    the lookup and the insert, so no lock is required (the previous
    implementation carried an ``asyncio.Lock`` that was never acquired —
    dead code that suggested a synchronization need that does not exist).
    """

    def __init__(
        self,
        failure_threshold: int = 5,
        success_threshold: int = 2,
        timeout_seconds: int = 60,
    ):
        """Initialize circuit breaker registry.

        Args:
            failure_threshold: Failures required to open a circuit
            success_threshold: Successes required to close a circuit
            timeout_seconds: Time before a breaker attempts recovery
        """
        self.failure_threshold = failure_threshold
        self.success_threshold = success_threshold
        self.timeout_seconds = timeout_seconds
        # One breaker per endpoint URL, created lazily on first access.
        self._breakers: dict[str, EndpointCircuitBreaker] = {}

    def get_breaker(self, endpoint_url: str) -> EndpointCircuitBreaker:
        """Get or create the circuit breaker for an endpoint.

        Args:
            endpoint_url: URL of the endpoint

        Returns:
            EndpointCircuitBreaker instance (same object on repeated calls)
        """
        # Single lookup instead of `in` + indexing.
        breaker = self._breakers.get(endpoint_url)
        if breaker is None:
            breaker = EndpointCircuitBreaker(
                endpoint_url,
                failure_threshold=self.failure_threshold,
                success_threshold=self.success_threshold,
                timeout_seconds=self.timeout_seconds,
            )
            self._breakers[endpoint_url] = breaker
        return breaker

    def get_state(self, endpoint_url: str) -> CircuitState:
        """Get the current state of an endpoint's circuit breaker.

        Args:
            endpoint_url: URL of the endpoint

        Returns:
            Current circuit state (creates the breaker if missing)
        """
        return self.get_breaker(endpoint_url).get_state()

    def get_all_stats(self) -> dict[str, CircuitBreakerStats]:
        """Get statistics for every breaker created so far.

        Returns:
            Mapping of endpoint URLs to statistics
        """
        return {url: breaker.get_stats() for url, breaker in self._breakers.items()}
class LoadBalancer:
    """Select one endpoint out of many according to a configured strategy.

    In-flight request counts are tracked so the LEAST_CONNECTIONS strategy
    can prefer the least busy endpoint; shared state is mutated under an
    asyncio lock.
    """

    def __init__(
        self, strategy: LoadBalancerStrategy = LoadBalancerStrategy.ROUND_ROBIN
    ):
        """Initialize load balancer.

        Args:
            strategy: Load balancing strategy to use
        """
        self.strategy = strategy
        self._round_robin_index = 0
        self._lock = asyncio.Lock()
        self._in_flight_requests: dict[str, int] = {}

    async def select_endpoint(
        self,
        endpoints: List[str],
        circuit_breaker_registry: Optional["CircuitBreakerRegistry"] = None,
    ) -> Optional[str]:
        """Pick an endpoint, skipping any whose circuit breaker is OPEN.

        Args:
            endpoints: List of available endpoint URLs
            circuit_breaker_registry: Optional registry used to filter out
                unhealthy endpoints

        Returns:
            Selected endpoint URL, or None when no healthy candidate exists
        """
        if not endpoints:
            return None

        candidates = endpoints
        if circuit_breaker_registry is not None:
            from tetra_rp.runtime.circuit_breaker import CircuitState

            candidates = [
                url
                for url in endpoints
                if circuit_breaker_registry.get_state(url) != CircuitState.OPEN
            ]
            if not candidates:
                logger.warning(
                    f"All {len(endpoints)} endpoints are unhealthy (circuit open)"
                )
                return None

        # Dispatch table; unknown strategies fall back to round-robin.
        chooser = {
            LoadBalancerStrategy.ROUND_ROBIN: self._round_robin_select,
            LoadBalancerStrategy.LEAST_CONNECTIONS: self._least_connections_select,
            LoadBalancerStrategy.RANDOM: self._random_select,
        }.get(self.strategy, self._round_robin_select)
        return await chooser(candidates)

    async def _round_robin_select(self, endpoints: List[str]) -> str:
        """Rotate through *endpoints*, advancing one slot per call.

        Args:
            endpoints: List of available endpoints

        Returns:
            Selected endpoint URL
        """
        async with self._lock:
            slot = self._round_robin_index % len(endpoints)
            selected = endpoints[slot]
            self._round_robin_index += 1
            logger.debug(
                f"Load balancer: ROUND_ROBIN selected {selected} "
                f"(index {self._round_robin_index - 1})"
            )
            return selected

    async def _least_connections_select(self, endpoints: List[str]) -> str:
        """Pick the endpoint with the fewest in-flight requests.

        Args:
            endpoints: List of available endpoints

        Returns:
            Selected endpoint URL
        """
        async with self._lock:
            # Make sure every candidate has a count so min() is well-defined.
            for url in endpoints:
                self._in_flight_requests.setdefault(url, 0)

            selected = min(endpoints, key=lambda e: self._in_flight_requests.get(e, 0))

            logger.debug(
                f"Load balancer: LEAST_CONNECTIONS selected {selected} "
                f"({self._in_flight_requests.get(selected, 0)} in-flight)"
            )
            return selected

    async def _random_select(self, endpoints: List[str]) -> str:
        """Pick an endpoint uniformly at random.

        Args:
            endpoints: List of available endpoints

        Returns:
            Selected endpoint URL
        """
        selected = random.choice(endpoints)
        logger.debug(f"Load balancer: RANDOM selected {selected}")
        return selected

    async def record_request(self, endpoint: str) -> None:
        """Note that a request has started against *endpoint*.

        Args:
            endpoint: Endpoint URL
        """
        async with self._lock:
            current = self._in_flight_requests.get(endpoint, 0)
            self._in_flight_requests[endpoint] = current + 1

    async def record_request_complete(self, endpoint: str) -> None:
        """Note that a request against *endpoint* finished (floored at zero).

        Args:
            endpoint: Endpoint URL
        """
        async with self._lock:
            if endpoint in self._in_flight_requests:
                current = self._in_flight_requests[endpoint]
                self._in_flight_requests[endpoint] = max(0, current - 1)

    def get_stats(self) -> dict[str, int]:
        """Get a snapshot of current in-flight request counts.

        Returns:
            Mapping of endpoint URLs to in-flight request counts
        """
        return dict(self._in_flight_requests)
logger = logging.getLogger(__name__)


class MetricType(Enum):
    """Kinds of metrics the collector understands."""

    COUNTER = "counter"
    GAUGE = "gauge"
    HISTOGRAM = "histogram"


@dataclass
class Metric:
    """One observation: type, name, value, and an arbitrary label set."""

    metric_type: MetricType
    metric_name: str
    value: float
    labels: Dict[str, Any]

    def to_dict(self) -> Dict[str, Any]:
        """Convert metric to dictionary.

        Returns:
            Dictionary representation of metric
        """
        return asdict(self)


class MetricsCollector:
    """Emit metrics as structured log records under a namespace."""

    def __init__(self, namespace: str = "tetra.metrics", enabled: bool = True):
        """Initialize metrics collector.

        Args:
            namespace: Namespace for metrics (used in structured logging)
            enabled: Whether metrics collection is enabled
        """
        self.namespace = namespace
        self.enabled = enabled

    def counter(
        self,
        name: str,
        value: float = 1.0,
        labels: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Record a counter metric (cumulative).

        Args:
            name: Name of the metric
            value: Value to add to counter (default: 1.0)
            labels: Optional labels/tags for the metric
        """
        self._record(MetricType.COUNTER, name, value, labels)

    def gauge(
        self,
        name: str,
        value: float,
        labels: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Record a gauge metric (point-in-time value).

        Args:
            name: Name of the metric
            value: Current value of the gauge
            labels: Optional labels/tags for the metric
        """
        self._record(MetricType.GAUGE, name, value, labels)

    def histogram(
        self,
        name: str,
        value: float,
        labels: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Record a histogram metric (distribution).

        Args:
            name: Name of the metric
            value: Value to add to histogram
            labels: Optional labels/tags for the metric
        """
        self._record(MetricType.HISTOGRAM, name, value, labels)

    def _record(
        self,
        metric_type: MetricType,
        name: str,
        value: float,
        labels: Optional[Dict[str, Any]],
    ) -> None:
        """Shared path for all metric types; no-op when collection is off."""
        if not self.enabled:
            return
        self._emit(Metric(metric_type, name, value, labels or {}))

    def _emit(self, metric: Metric) -> None:
        """Emit metric via structured logging.

        Args:
            metric: Metric to emit
        """
        try:
            logger.info(
                f"[METRIC] {metric.metric_name}={metric.value}",
                extra={
                    "namespace": self.namespace,
                    "metric": metric.to_dict(),
                },
            )
        except Exception as e:
            # Metrics are best-effort: never let emission break the caller.
            logger.error(f"Failed to emit metric {metric.metric_name}: {e}")


# Process-wide collector, created lazily by get_metrics_collector().
_collector: Optional[MetricsCollector] = None


def get_metrics_collector(
    namespace: str = "tetra.metrics", enabled: bool = True
) -> MetricsCollector:
    """Get global metrics collector (lazy-loaded).

    Note: the arguments only take effect on the very first call.

    Args:
        namespace: Namespace for metrics
        enabled: Whether metrics collection is enabled

    Returns:
        MetricsCollector instance
    """
    global _collector
    if _collector is None:
        _collector = MetricsCollector(namespace=namespace, enabled=enabled)
    return _collector


def set_metrics_collector(collector: MetricsCollector) -> None:
    """Set global metrics collector (for testing).

    Args:
        collector: MetricsCollector instance
    """
    global _collector
    _collector = collector
class CircuitBreakerMetrics:
    """Emit circuit-breaker related metrics through a collector."""

    def __init__(self, collector: Optional["MetricsCollector"] = None):
        """Initialize circuit breaker metrics helper.

        Args:
            collector: Optional MetricsCollector instance (uses global if not provided)
        """
        self.collector = collector or get_metrics_collector()

    def state_changed(
        self, endpoint_url: str, new_state: str, previous_state: str
    ) -> None:
        """Emit metric when circuit breaker state changes.

        Args:
            endpoint_url: URL of the endpoint
            new_state: New circuit state
            previous_state: Previous circuit state
        """
        tags = {
            "endpoint_url": endpoint_url,
            "new_state": new_state,
            "previous_state": previous_state,
        }
        self.collector.counter(
            "circuit_breaker_state_changes", value=1.0, labels=tags
        )

    def endpoint_requests(self, endpoint_url: str, status: str, count: int = 1) -> None:
        """Emit metric for endpoint requests.

        Args:
            endpoint_url: URL of the endpoint
            status: Request status (success, failure, etc.)
            count: Number of requests
        """
        self.collector.counter(
            "endpoint_requests",
            value=float(count),
            labels={"endpoint_url": endpoint_url, "status": status},
        )

    def endpoint_latency(self, endpoint_url: str, latency_ms: float) -> None:
        """Emit metric for endpoint latency.

        Args:
            endpoint_url: URL of the endpoint
            latency_ms: Latency in milliseconds
        """
        self.collector.histogram(
            "endpoint_latency",
            value=latency_ms,
            labels={"endpoint_url": endpoint_url},
        )

    def in_flight_requests(self, endpoint_url: str, count: int) -> None:
        """Emit metric for in-flight requests.

        Args:
            endpoint_url: URL of the endpoint
            count: Current number of in-flight requests
        """
        self.collector.gauge(
            "in_flight_requests",
            value=float(count),
            labels={"endpoint_url": endpoint_url},
        )


class RetryMetrics:
    """Emit retry related metrics through a collector."""

    def __init__(self, collector: Optional["MetricsCollector"] = None):
        """Initialize retry metrics helper.

        Args:
            collector: Optional MetricsCollector instance (uses global if not provided)
        """
        self.collector = collector or get_metrics_collector()

    def retry_attempt(
        self, function_name: str, attempt: int, error: Optional[str] = None
    ) -> None:
        """Emit metric for retry attempt.

        Args:
            function_name: Name of the function being retried
            attempt: Attempt number
            error: Optional error message
        """
        tags = {"function_name": function_name, "attempt": str(attempt)}
        if error:
            # Only attach the error label when an error message was given.
            tags["error"] = error
        self.collector.counter("retry_attempts", value=1.0, labels=tags)

    def retry_success(self, function_name: str, total_attempts: int) -> None:
        """Emit metric for successful retry.

        Args:
            function_name: Name of the function
            total_attempts: Total attempts made before success
        """
        self.collector.counter(
            "retry_success",
            value=1.0,
            labels={
                "function_name": function_name,
                "attempts": str(total_attempts),
            },
        )

    def retry_exhausted(self, function_name: str, max_attempts: int) -> None:
        """Emit metric when max retries exceeded.

        Args:
            function_name: Name of the function
            max_attempts: Maximum attempts configured
        """
        self.collector.counter(
            "retry_exhausted",
            value=1.0,
            labels={
                "function_name": function_name,
                "max_attempts": str(max_attempts),
            },
        )


class LoadBalancerMetrics:
    """Emit load-balancer related metrics through a collector."""

    def __init__(self, collector: Optional["MetricsCollector"] = None):
        """Initialize load balancer metrics helper.

        Args:
            collector: Optional MetricsCollector instance (uses global if not provided)
        """
        self.collector = collector or get_metrics_collector()

    def endpoint_selected(
        self, strategy: str, endpoint_url: str, total_candidates: int
    ) -> None:
        """Emit metric when endpoint is selected.

        Args:
            strategy: Load balancing strategy used
            endpoint_url: Selected endpoint URL
            total_candidates: Total candidate endpoints
        """
        self.collector.counter(
            "load_balancer_selection",
            value=1.0,
            labels={
                "strategy": strategy,
                "endpoint_url": endpoint_url,
                "candidates": str(total_candidates),
            },
        )
+ + Args: + strategy: Load balancing strategy used + endpoint_url: Selected endpoint URL + total_candidates: Total candidate endpoints + """ + self.collector.counter( + "load_balancer_selection", + value=1.0, + labels={ + "strategy": strategy, + "endpoint_url": endpoint_url, + "candidates": str(total_candidates), + }, + ) diff --git a/src/tetra_rp/runtime/reliability_config.py b/src/tetra_rp/runtime/reliability_config.py new file mode 100644 index 00000000..2bd063e6 --- /dev/null +++ b/src/tetra_rp/runtime/reliability_config.py @@ -0,0 +1,148 @@ +"""Centralized configuration for reliability features.""" + +import asyncio +import os +from dataclasses import dataclass, field +from enum import Enum +from typing import Optional + + +class LoadBalancerStrategy(Enum): + """Load balancing strategies for endpoint selection.""" + + ROUND_ROBIN = "round_robin" + LEAST_CONNECTIONS = "least_connections" + RANDOM = "random" + + +@dataclass +class CircuitBreakerConfig: + """Configuration for circuit breaker behavior.""" + + enabled: bool = True + failure_threshold: int = 5 + success_threshold: int = 2 + timeout_seconds: int = 60 + window_size: int = 10 + + +@dataclass +class LoadBalancerConfig: + """Configuration for load balancer behavior.""" + + enabled: bool = False + strategy: LoadBalancerStrategy = LoadBalancerStrategy.ROUND_ROBIN + + +@dataclass +class RetryConfig: + """Configuration for retry behavior with exponential backoff.""" + + enabled: bool = True + max_attempts: int = 3 + base_delay: float = 0.5 + max_delay: float = 10.0 + jitter: float = 0.2 + retryable_exceptions: tuple = field(default=(asyncio.TimeoutError, ConnectionError)) + retryable_status_codes: set = field( + default_factory=lambda: {408, 429, 500, 502, 503, 504} + ) + + +@dataclass +class MetricsConfig: + """Configuration for metrics collection.""" + + enabled: bool = True + namespace: str = "tetra.metrics" + + +@dataclass +class ReliabilityConfig: + """Centralized reliability features configuration.""" 
# Strings (lower-cased, stripped) that count as "true" for boolean env vars.
_TRUTHY_VALUES = {"1", "true", "yes", "on"}


def _env_bool(name: str, default: str) -> bool:
    """Parse a boolean environment variable.

    Accepts "1"/"true"/"yes"/"on" in any case, ignoring surrounding
    whitespace.  The previous implementation only recognized the literal
    string "true", so common values like "1" silently read as False.

    Args:
        name: Environment variable name.
        default: Value to use when the variable is unset.

    Returns:
        Parsed boolean value.
    """
    return os.getenv(name, default).strip().lower() in _TRUTHY_VALUES


@dataclass
class ReliabilityConfig:
    """Centralized reliability features configuration."""

    circuit_breaker: CircuitBreakerConfig = field(default_factory=CircuitBreakerConfig)
    load_balancer: LoadBalancerConfig = field(default_factory=LoadBalancerConfig)
    retry: RetryConfig = field(default_factory=RetryConfig)
    metrics: MetricsConfig = field(default_factory=MetricsConfig)

    @classmethod
    def from_env(cls) -> "ReliabilityConfig":
        """Load configuration from environment variables.

        Environment variables:
            - TETRA_CIRCUIT_BREAKER_ENABLED: Enable circuit breaker (default: true)
            - TETRA_CB_FAILURE_THRESHOLD: Failures before opening (default: 5)
            - TETRA_CB_SUCCESS_THRESHOLD: Successes to close (default: 2)
            - TETRA_CB_TIMEOUT_SECONDS: Time before half-open (default: 60)
            - TETRA_LOAD_BALANCER_ENABLED: Enable load balancer (default: false)
            - TETRA_LB_STRATEGY: Load balancer strategy (default: round_robin)
            - TETRA_RETRY_ENABLED: Enable retry (default: true)
            - TETRA_RETRY_MAX_ATTEMPTS: Max retry attempts (default: 3)
            - TETRA_RETRY_BASE_DELAY: Base delay for backoff (default: 0.5)
            - TETRA_METRICS_ENABLED: Enable metrics (default: true)

        Boolean variables accept "1"/"true"/"yes"/"on" (case-insensitive).

        Returns:
            ReliabilityConfig initialized from environment variables.
        """
        circuit_breaker = CircuitBreakerConfig(
            enabled=_env_bool("TETRA_CIRCUIT_BREAKER_ENABLED", "true"),
            failure_threshold=int(os.getenv("TETRA_CB_FAILURE_THRESHOLD", "5")),
            success_threshold=int(os.getenv("TETRA_CB_SUCCESS_THRESHOLD", "2")),
            timeout_seconds=int(os.getenv("TETRA_CB_TIMEOUT_SECONDS", "60")),
        )

        strategy_str = os.getenv("TETRA_LB_STRATEGY", "round_robin").lower()
        try:
            strategy = LoadBalancerStrategy(strategy_str)
        except ValueError:
            # Unknown strategy names fall back to round-robin instead of crashing.
            strategy = LoadBalancerStrategy.ROUND_ROBIN

        load_balancer = LoadBalancerConfig(
            enabled=_env_bool("TETRA_LOAD_BALANCER_ENABLED", "false"),
            strategy=strategy,
        )

        retry = RetryConfig(
            enabled=_env_bool("TETRA_RETRY_ENABLED", "true"),
            max_attempts=int(os.getenv("TETRA_RETRY_MAX_ATTEMPTS", "3")),
            base_delay=float(os.getenv("TETRA_RETRY_BASE_DELAY", "0.5")),
        )

        metrics = MetricsConfig(
            enabled=_env_bool("TETRA_METRICS_ENABLED", "true"),
        )

        return cls(
            circuit_breaker=circuit_breaker,
            load_balancer=load_balancer,
            retry=retry,
            metrics=metrics,
        )


# Global default configuration, created lazily by get_reliability_config().
_config: Optional[ReliabilityConfig] = None


def get_reliability_config() -> ReliabilityConfig:
    """Get global reliability configuration (lazy-loaded).

    Returns:
        ReliabilityConfig instance initialized from environment.
    """
    global _config
    if _config is None:
        _config = ReliabilityConfig.from_env()
    return _config


def set_reliability_config(config: ReliabilityConfig) -> None:
    """Set global reliability configuration (for testing).

    Args:
        config: ReliabilityConfig to set as global.
    """
    global _config
    _config = config
+ """ + global _config + _config = config diff --git a/src/tetra_rp/runtime/retry_manager.py b/src/tetra_rp/runtime/retry_manager.py new file mode 100644 index 00000000..7435ca22 --- /dev/null +++ b/src/tetra_rp/runtime/retry_manager.py @@ -0,0 +1,118 @@ +"""Retry logic with exponential backoff for failed remote calls.""" + +import asyncio +import logging +from typing import Any, Callable, Optional, Set, Tuple, Type + +from tetra_rp.core.utils.backoff import get_backoff_delay + +logger = logging.getLogger(__name__) + + +class RetryExhaustedError(Exception): + """Raised when max retry attempts are exceeded.""" + + pass + + +async def retry_with_backoff( + func: Callable[..., Any], + max_attempts: int = 3, + base_delay: float = 0.5, + max_delay: float = 10.0, + jitter: float = 0.2, + retryable_exceptions: Optional[Tuple[Type[Exception], ...]] = None, + retryable_status_codes: Optional[Set[int]] = None, + circuit_breaker: Optional[Any] = None, + *args: Any, + **kwargs: Any, +) -> Any: + """Execute async function with retry and exponential backoff. 
+ + Args: + func: Async function to execute + max_attempts: Maximum number of attempts (default: 3) + base_delay: Base delay between retries in seconds (default: 0.5) + max_delay: Maximum delay between retries (default: 10.0) + jitter: Jitter factor (0.0-1.0) to add randomness (default: 0.2) + retryable_exceptions: Tuple of exception types to retry on + (default: (asyncio.TimeoutError, ConnectionError)) + retryable_status_codes: Set of HTTP status codes to retry on + (default: {408, 429, 500, 502, 503, 504}) + circuit_breaker: Optional circuit breaker to check before retry + *args: Positional arguments for func + **kwargs: Keyword arguments for func + + Returns: + Result from successful function call + + Raises: + RetryExhaustedError: If max attempts exceeded + Exception: If non-retryable exception occurs + """ + if retryable_exceptions is None: + retryable_exceptions = (asyncio.TimeoutError, ConnectionError) + + if retryable_status_codes is None: + retryable_status_codes = {408, 429, 500, 502, 503, 504} + + last_exception: Optional[Exception] = None + + for attempt in range(max_attempts): + try: + # Check circuit breaker before attempting + if circuit_breaker is not None: + from tetra_rp.runtime.circuit_breaker import CircuitState + + if circuit_breaker.get_state() == CircuitState.OPEN: + raise RuntimeError( + f"Circuit breaker OPEN, skipping retry attempt {attempt + 1}" + ) + + result = await func(*args, **kwargs) + + # Log success on retry + if attempt > 0: + logger.info(f"Retry succeeded on attempt {attempt + 1}/{max_attempts}") + + return result + + except Exception as e: + last_exception = e + + # Check if exception is retryable + if not isinstance(e, retryable_exceptions): + logger.debug( + f"Non-retryable exception in {func.__name__}: {type(e).__name__}" + ) + raise + + # Check for retryable status codes (if exception has status_code) + if hasattr(e, "status_code"): + if e.status_code not in retryable_status_codes: # type: ignore + logger.debug( + 
f"Non-retryable status code {e.status_code} in {func.__name__}" + ) + raise + + # If this is the last attempt, don't retry + if attempt >= max_attempts - 1: + logger.warning( + f"Max retries ({max_attempts}) exhausted for {func.__name__}" + ) + raise RetryExhaustedError( + f"Failed after {max_attempts} attempts: {e}" + ) from e + + # Calculate delay with exponential backoff and jitter + delay = get_backoff_delay(attempt, base_delay, max_delay, jitter=jitter) + logger.debug( + f"Retry {attempt + 1}/{max_attempts} for {func.__name__} " + f"after {delay:.2f}s" + ) + await asyncio.sleep(delay) + + # Should never reach here, but handle edge case + if last_exception: + raise last_exception + raise RetryExhaustedError(f"Failed after {max_attempts} attempts") diff --git a/tests/unit/runtime/test_circuit_breaker.py b/tests/unit/runtime/test_circuit_breaker.py new file mode 100644 index 00000000..39d405ce --- /dev/null +++ b/tests/unit/runtime/test_circuit_breaker.py @@ -0,0 +1,188 @@ +"""Tests for circuit breaker module.""" + +import asyncio + +import pytest + +from tetra_rp.runtime.circuit_breaker import ( + CircuitBreakerOpenError, + CircuitState, + EndpointCircuitBreaker, +) + + +class TestCircuitState: + """Test CircuitState enum.""" + + def test_states(self): + """Test circuit states exist.""" + assert CircuitState.CLOSED.value == "closed" + assert CircuitState.OPEN.value == "open" + assert CircuitState.HALF_OPEN.value == "half_open" + + +class TestEndpointCircuitBreaker: + """Test EndpointCircuitBreaker class.""" + + @pytest.mark.asyncio + async def test_successful_execution(self): + """Test successful function execution.""" + breaker = EndpointCircuitBreaker( + "http://example.com", + failure_threshold=5, + timeout_seconds=60, + ) + + async def success_func(): + return "success" + + result = await breaker.execute(success_func) + assert result == "success" + assert breaker.get_state() == CircuitState.CLOSED + + @pytest.mark.asyncio + async def 
test_failed_execution_within_threshold(self): + """Test failed execution within threshold.""" + breaker = EndpointCircuitBreaker( + "http://example.com", + failure_threshold=5, + timeout_seconds=60, + ) + + async def failing_func(): + raise ConnectionError("Connection failed") + + for _ in range(4): # 4 failures, threshold is 5 + with pytest.raises(ConnectionError): + await breaker.execute(failing_func) + assert breaker.get_state() == CircuitState.CLOSED + + @pytest.mark.asyncio + async def test_circuit_opens_at_threshold(self): + """Test circuit opens when failure threshold reached.""" + breaker = EndpointCircuitBreaker( + "http://example.com", + failure_threshold=3, + timeout_seconds=60, + ) + + async def failing_func(): + raise ConnectionError("Connection failed") + + # Reach threshold + for _ in range(3): + with pytest.raises(ConnectionError): + await breaker.execute(failing_func) + + # Circuit should be OPEN now + assert breaker.get_state() == CircuitState.OPEN + + # Further requests should fail immediately + with pytest.raises(CircuitBreakerOpenError): + await breaker.execute(failing_func) + + @pytest.mark.asyncio + async def test_circuit_half_open_after_timeout(self): + """Test circuit transitions to HALF_OPEN after timeout.""" + breaker = EndpointCircuitBreaker( + "http://example.com", + failure_threshold=2, + timeout_seconds=1, + ) + + async def failing_func(): + raise ConnectionError("Connection failed") + + # Open circuit + for _ in range(2): + with pytest.raises(ConnectionError): + await breaker.execute(failing_func) + + assert breaker.get_state() == CircuitState.OPEN + + # Wait for timeout + await asyncio.sleep(1.1) + + # Next attempt should transition to HALF_OPEN + async def success_func(): + return "recovered" + + await breaker.execute(success_func) + # First success in HALF_OPEN doesn't close circuit yet + assert breaker.get_state() == CircuitState.HALF_OPEN + + @pytest.mark.asyncio + async def test_circuit_closes_after_success_threshold(self): + 
"""Test circuit closes after enough successes.""" + breaker = EndpointCircuitBreaker( + "http://example.com", + failure_threshold=2, + success_threshold=2, + timeout_seconds=1, + ) + + async def failing_func(): + raise ConnectionError("Connection failed") + + # Open circuit + for _ in range(2): + with pytest.raises(ConnectionError): + await breaker.execute(failing_func) + + assert breaker.get_state() == CircuitState.OPEN + + # Wait for timeout + await asyncio.sleep(1.1) + + # Succeed enough times to close circuit + async def success_func(): + return "success" + + for _ in range(2): + result = await breaker.execute(success_func) + assert result == "success" + + assert breaker.get_state() == CircuitState.CLOSED + + @pytest.mark.asyncio + async def test_get_stats(self): + """Test getting circuit breaker statistics.""" + breaker = EndpointCircuitBreaker("http://example.com") + + async def success_func(): + return "ok" + + await breaker.execute(success_func) + stats = breaker.get_stats() + assert stats.success_count == 1 + assert stats.failure_count == 0 + assert stats.state == CircuitState.CLOSED + assert stats.total_requests == 1 + + @pytest.mark.asyncio + async def test_half_open_resets_on_failure(self): + """Test that failure in HALF_OPEN opens circuit again.""" + breaker = EndpointCircuitBreaker( + "http://example.com", + failure_threshold=1, + timeout_seconds=1, + ) + + async def failing_func(): + raise ConnectionError("Connection failed") + + # Open circuit + with pytest.raises(ConnectionError): + await breaker.execute(failing_func) + + assert breaker.get_state() == CircuitState.OPEN + + # Wait for timeout + await asyncio.sleep(1.1) + + # Attempt recovery, should transition to HALF_OPEN + with pytest.raises(ConnectionError): + await breaker.execute(failing_func) + + # Should transition back to OPEN on first failure + assert breaker.get_state() == CircuitState.OPEN diff --git a/tests/unit/runtime/test_load_balancer.py b/tests/unit/runtime/test_load_balancer.py new 
class TestLoadBalancer:
    """Test LoadBalancer class."""

    def test_round_robin_selection(self):
        """The internal index cycles deterministically over endpoints."""
        lb = LoadBalancer(strategy=LoadBalancerStrategy.ROUND_ROBIN)
        endpoints = ["http://a.com", "http://b.com", "http://c.com"]

        picked = []
        for _ in range(9):
            idx = lb._round_robin_index
            picked.append(endpoints[idx % len(endpoints)])
            lb._round_robin_index += 1

        assert picked[0] == "http://a.com"
        assert picked[1] == "http://b.com"
        assert picked[2] == "http://c.com"
        assert picked[3] == "http://a.com"

    @pytest.mark.asyncio
    async def test_select_endpoint_round_robin(self):
        """select_endpoint alternates through the list in order."""
        lb = LoadBalancer(strategy=LoadBalancerStrategy.ROUND_ROBIN)
        endpoints = ["http://a.com", "http://b.com"]

        first = await lb.select_endpoint(endpoints)
        second = await lb.select_endpoint(endpoints)
        third = await lb.select_endpoint(endpoints)

        assert first == "http://a.com"
        assert second == "http://b.com"
        assert third == "http://a.com"

    @pytest.mark.asyncio
    async def test_select_endpoint_random(self):
        """Random strategy always returns a member of the candidate list."""
        lb = LoadBalancer(strategy=LoadBalancerStrategy.RANDOM)
        endpoints = ["http://a.com", "http://b.com"]

        assert await lb.select_endpoint(endpoints) in endpoints

    @pytest.mark.asyncio
    async def test_select_endpoint_least_connections(self):
        """The less busy endpoint wins under LEAST_CONNECTIONS."""
        lb = LoadBalancer(strategy=LoadBalancerStrategy.LEAST_CONNECTIONS)
        endpoints = ["http://a.com", "http://b.com"]

        # Load up the first endpoint with two in-flight requests.
        await lb.record_request(endpoints[0])
        await lb.record_request(endpoints[0])

        assert await lb.select_endpoint(endpoints) == endpoints[1]

    @pytest.mark.asyncio
    async def test_empty_endpoints_returns_none(self):
        """An empty candidate list yields None."""
        lb = LoadBalancer()
        assert await lb.select_endpoint([]) is None

    @pytest.mark.asyncio
    async def test_record_request_and_complete(self):
        """In-flight counts rise on start and fall on completion."""
        lb = LoadBalancer()
        endpoint = "http://a.com"

        await lb.record_request(endpoint)
        assert lb.get_stats()[endpoint] == 1

        await lb.record_request(endpoint)
        assert lb.get_stats()[endpoint] == 2

        await lb.record_request_complete(endpoint)
        assert lb.get_stats()[endpoint] == 1

    @pytest.mark.asyncio
    async def test_record_request_complete_does_not_go_negative(self):
        """Completing without a matching start never goes below zero."""
        lb = LoadBalancer()
        endpoint = "http://a.com"

        await lb.record_request_complete(endpoint)
        assert lb.get_stats().get(endpoint, 0) == 0

    @pytest.mark.asyncio
    async def test_select_endpoint_with_circuit_breaker(self):
        """Endpoints with an open circuit are excluded from selection."""

        class MockCircuitBreaker:
            def __init__(self, open_endpoints):
                self.open_endpoints = open_endpoints

            def get_state(self, endpoint):
                from tetra_rp.runtime.circuit_breaker import CircuitState

                if endpoint in self.open_endpoints:
                    return CircuitState.OPEN
                return CircuitState.CLOSED

        lb = LoadBalancer(strategy=LoadBalancerStrategy.ROUND_ROBIN)
        endpoints = ["http://a.com", "http://b.com", "http://c.com"]
        registry = MockCircuitBreaker({"http://a.com"})

        assert await lb.select_endpoint(endpoints, registry) != "http://a.com"

    @pytest.mark.asyncio
    async def test_all_endpoints_unhealthy_returns_none(self):
        """When every circuit is open, selection yields None."""

        class MockCircuitBreaker:
            def get_state(self, endpoint):
                from tetra_rp.runtime.circuit_breaker import CircuitState

                return CircuitState.OPEN

        lb = LoadBalancer()
        endpoints = ["http://a.com", "http://b.com"]
        registry = MockCircuitBreaker()

        assert await lb.select_endpoint(endpoints, registry) is None

    def test_get_stats(self):
        """A fresh balancer reports an empty stats mapping."""
        lb = LoadBalancer()
        stats = lb.get_stats()
        assert isinstance(stats, dict)
        assert len(stats) == 0
test_all_endpoints_unhealthy_returns_none(self): + """Test that all unhealthy endpoints returns None.""" + + class MockCircuitBreaker: + def get_state(self, endpoint): + from tetra_rp.runtime.circuit_breaker import CircuitState + + return CircuitState.OPEN + + lb = LoadBalancer() + endpoints = ["http://a.com", "http://b.com"] + circuit_breaker = MockCircuitBreaker() + + selected = await lb.select_endpoint(endpoints, circuit_breaker) + assert selected is None + + def test_get_stats(self): + """Test getting load balancer statistics.""" + lb = LoadBalancer() + stats = lb.get_stats() + assert isinstance(stats, dict) + assert len(stats) == 0 diff --git a/tests/unit/runtime/test_reliability_config.py b/tests/unit/runtime/test_reliability_config.py new file mode 100644 index 00000000..52473bf3 --- /dev/null +++ b/tests/unit/runtime/test_reliability_config.py @@ -0,0 +1,181 @@ +"""Tests for reliability configuration module.""" + +from tetra_rp.runtime.reliability_config import ( + CircuitBreakerConfig, + LoadBalancerConfig, + LoadBalancerStrategy, + MetricsConfig, + ReliabilityConfig, + RetryConfig, + get_reliability_config, +) + + +class TestCircuitBreakerConfig: + """Test CircuitBreakerConfig dataclass.""" + + def test_defaults(self): + """Test default values.""" + config = CircuitBreakerConfig() + assert config.enabled is True + assert config.failure_threshold == 5 + assert config.success_threshold == 2 + assert config.timeout_seconds == 60 + assert config.window_size == 10 + + def test_custom_values(self): + """Test with custom values.""" + config = CircuitBreakerConfig( + enabled=False, + failure_threshold=10, + success_threshold=3, + timeout_seconds=30, + window_size=20, + ) + assert config.enabled is False + assert config.failure_threshold == 10 + assert config.success_threshold == 3 + assert config.timeout_seconds == 30 + assert config.window_size == 20 + + +class TestLoadBalancerConfig: + """Test LoadBalancerConfig dataclass.""" + + def test_defaults(self): + 
"""Test default values.""" + config = LoadBalancerConfig() + assert config.enabled is False + assert config.strategy == LoadBalancerStrategy.ROUND_ROBIN + + def test_custom_values(self): + """Test with custom values.""" + config = LoadBalancerConfig( + enabled=True, + strategy=LoadBalancerStrategy.LEAST_CONNECTIONS, + ) + assert config.enabled is True + assert config.strategy == LoadBalancerStrategy.LEAST_CONNECTIONS + + +class TestRetryConfig: + """Test RetryConfig dataclass.""" + + def test_defaults(self): + """Test default values.""" + config = RetryConfig() + assert config.enabled is True + assert config.max_attempts == 3 + assert config.base_delay == 0.5 + assert config.max_delay == 10.0 + assert config.jitter == 0.2 + assert 408 in config.retryable_status_codes + assert 500 in config.retryable_status_codes + + def test_custom_values(self): + """Test with custom values.""" + config = RetryConfig( + enabled=False, + max_attempts=5, + base_delay=1.0, + max_delay=20.0, + jitter=0.1, + ) + assert config.enabled is False + assert config.max_attempts == 5 + assert config.base_delay == 1.0 + assert config.max_delay == 20.0 + assert config.jitter == 0.1 + + +class TestMetricsConfig: + """Test MetricsConfig dataclass.""" + + def test_defaults(self): + """Test default values.""" + config = MetricsConfig() + assert config.enabled is True + assert config.namespace == "tetra.metrics" + + def test_custom_values(self): + """Test with custom values.""" + config = MetricsConfig(enabled=False, namespace="custom.metrics") + assert config.enabled is False + assert config.namespace == "custom.metrics" + + +class TestReliabilityConfig: + """Test ReliabilityConfig dataclass.""" + + def test_defaults(self): + """Test default values.""" + config = ReliabilityConfig() + assert config.circuit_breaker is not None + assert config.load_balancer is not None + assert config.retry is not None + assert config.metrics is not None + assert config.circuit_breaker.enabled is True + assert 
config.load_balancer.enabled is False + assert config.retry.enabled is True + assert config.metrics.enabled is True + + def test_custom_nested_configs(self): + """Test with custom nested configurations.""" + cb_config = CircuitBreakerConfig(enabled=False) + lb_config = LoadBalancerConfig(enabled=True) + config = ReliabilityConfig( + circuit_breaker=cb_config, + load_balancer=lb_config, + ) + assert config.circuit_breaker.enabled is False + assert config.load_balancer.enabled is True + + def test_from_env_default(self, monkeypatch): + """Test from_env with no environment variables.""" + monkeypatch.delenv("TETRA_CIRCUIT_BREAKER_ENABLED", raising=False) + config = ReliabilityConfig.from_env() + assert config.circuit_breaker.enabled is True + assert config.load_balancer.enabled is False + assert config.retry.enabled is True + + def test_from_env_custom(self, monkeypatch): + """Test from_env with custom environment variables.""" + monkeypatch.setenv("TETRA_CIRCUIT_BREAKER_ENABLED", "false") + monkeypatch.setenv("TETRA_LOAD_BALANCER_ENABLED", "true") + monkeypatch.setenv("TETRA_CB_FAILURE_THRESHOLD", "10") + config = ReliabilityConfig.from_env() + assert config.circuit_breaker.enabled is False + assert config.load_balancer.enabled is True + assert config.circuit_breaker.failure_threshold == 10 + + def test_from_env_load_balancer_strategy(self, monkeypatch): + """Test from_env with load balancer strategy.""" + monkeypatch.setenv("TETRA_LB_STRATEGY", "least_connections") + config = ReliabilityConfig.from_env() + assert config.load_balancer.strategy == LoadBalancerStrategy.LEAST_CONNECTIONS + + def test_from_env_invalid_strategy_defaults(self, monkeypatch): + """Test from_env with invalid strategy defaults to round_robin.""" + monkeypatch.setenv("TETRA_LB_STRATEGY", "invalid_strategy") + config = ReliabilityConfig.from_env() + assert config.load_balancer.strategy == LoadBalancerStrategy.ROUND_ROBIN + + +class TestLoadBalancerStrategy: + """Test LoadBalancerStrategy 
enum.""" + + def test_strategy_values(self): + """Test that strategies have correct values.""" + assert LoadBalancerStrategy.ROUND_ROBIN.value == "round_robin" + assert LoadBalancerStrategy.LEAST_CONNECTIONS.value == "least_connections" + assert LoadBalancerStrategy.RANDOM.value == "random" + + +class TestGlobalConfig: + """Test global configuration accessor.""" + + def test_get_reliability_config(self): + """Test getting global reliability config.""" + config = get_reliability_config() + assert isinstance(config, ReliabilityConfig) + assert config.circuit_breaker is not None diff --git a/tests/unit/runtime/test_retry_manager.py b/tests/unit/runtime/test_retry_manager.py new file mode 100644 index 00000000..122cf902 --- /dev/null +++ b/tests/unit/runtime/test_retry_manager.py @@ -0,0 +1,165 @@ +"""Tests for retry manager module.""" + +import asyncio + +import pytest + +from tetra_rp.runtime.retry_manager import RetryExhaustedError, retry_with_backoff + + +class TestRetryWithBackoff: + """Test retry_with_backoff function.""" + + @pytest.mark.asyncio + async def test_successful_first_attempt(self): + """Test successful execution on first attempt.""" + + async def success_func(): + return "success" + + result = await retry_with_backoff(success_func, max_attempts=3) + assert result == "success" + + @pytest.mark.asyncio + async def test_non_retryable_exception_raises_immediately(self): + """Test that non-retryable exceptions raise immediately.""" + + async def failing_func(): + raise ValueError("Non-retryable error") + + with pytest.raises(ValueError): + await retry_with_backoff(failing_func, max_attempts=3) + + @pytest.mark.asyncio + async def test_retryable_exception_retries(self): + """Test that retryable exceptions are retried.""" + attempt_count = 0 + + async def failing_then_success(): + nonlocal attempt_count + attempt_count += 1 + if attempt_count < 2: + raise ConnectionError("Connection failed") + return "success" + + result = await retry_with_backoff( + 
failing_then_success, + max_attempts=3, + base_delay=0.01, + max_delay=0.1, + ) + assert result == "success" + assert attempt_count == 2 + + @pytest.mark.asyncio + async def test_max_retries_exhausted(self): + """Test that RetryExhaustedError is raised after max attempts.""" + + async def always_fails(): + raise ConnectionError("Always fails") + + with pytest.raises(RetryExhaustedError): + await retry_with_backoff( + always_fails, + max_attempts=2, + base_delay=0.01, + max_delay=0.1, + ) + + @pytest.mark.asyncio + async def test_timeout_is_retryable(self): + """Test that asyncio.TimeoutError is retried by default.""" + attempt_count = 0 + + async def timeout_then_success(): + nonlocal attempt_count + attempt_count += 1 + if attempt_count < 2: + raise asyncio.TimeoutError("Request timed out") + return "success" + + result = await retry_with_backoff( + timeout_then_success, + max_attempts=3, + base_delay=0.01, + max_delay=0.1, + ) + assert result == "success" + + @pytest.mark.asyncio + async def test_custom_retryable_exceptions(self): + """Test with custom retryable exceptions.""" + + class CustomError(Exception): + pass + + attempt_count = 0 + + async def custom_error_then_success(): + nonlocal attempt_count + attempt_count += 1 + if attempt_count < 2: + raise CustomError("Custom error") + return "success" + + result = await retry_with_backoff( + custom_error_then_success, + max_attempts=3, + retryable_exceptions=(CustomError,), + base_delay=0.01, + max_delay=0.1, + ) + assert result == "success" + + @pytest.mark.asyncio + async def test_exponential_backoff(self): + """Test that backoff increases exponentially.""" + attempt_times = [] + + async def track_attempts(): + attempt_times.append(asyncio.get_event_loop().time()) + if len(attempt_times) < 3: + raise ConnectionError("Failed") + return "success" + + result = await retry_with_backoff( + track_attempts, + max_attempts=3, + base_delay=0.05, + max_delay=1.0, + jitter=0.0, # No jitter for predictable timing + ) + 
assert result == "success" + # Should have at least 3 attempts with delays between them + assert len(attempt_times) == 3 + + @pytest.mark.asyncio + async def test_with_args_and_kwargs(self): + """Test retry with function arguments.""" + + async def add(a, b): + return a + b + + result = await retry_with_backoff(add, max_attempts=1, a=2, b=3) + assert result == 5 + + @pytest.mark.asyncio + async def test_retry_with_circuit_breaker_open(self): + """Test that open circuit breaker prevents retries.""" + + class MockCircuitBreaker: + def get_state(self): + from tetra_rp.runtime.circuit_breaker import CircuitState + + return CircuitState.OPEN + + async def failing_func(): + raise ConnectionError("Failed") + + with pytest.raises(RuntimeError, match="Circuit breaker OPEN"): + await retry_with_backoff( + failing_func, + max_attempts=3, + circuit_breaker=MockCircuitBreaker(), + base_delay=0.01, + ) From f088118c51cfbdb1a0ce53ecf37f9d11bb765646 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dean=20Qui=C3=B1anola?= Date: Thu, 22 Jan 2026 12:31:35 -0800 Subject: [PATCH 06/11] chore: merge correction --- src/tetra_rp/core/utils/json.py | 46 ----- tests/unit/core/utils/test_json.py | 262 ----------------------------- 2 files changed, 308 deletions(-) delete mode 100644 src/tetra_rp/core/utils/json.py delete mode 100644 tests/unit/core/utils/test_json.py diff --git a/src/tetra_rp/core/utils/json.py b/src/tetra_rp/core/utils/json.py deleted file mode 100644 index 634c1dbf..00000000 --- a/src/tetra_rp/core/utils/json.py +++ /dev/null @@ -1,46 +0,0 @@ -"""Utilities for normalizing data structures for JSON serialization.""" - -from enum import Enum -from typing import Any - -from pydantic import BaseModel - - -def normalize_for_json(obj: Any) -> Any: - """Normalize an object for JSON serialization. - - Converts Pydantic models to dicts and Enum values to their values, - while recursively processing collections. - - Args: - obj: The object to normalize. 
- - Returns: - A JSON-serializable version of the object. - """ - # Handle primitives - if obj is None or isinstance(obj, (bool, int, float, str)): - return obj - - # Handle Enum - if isinstance(obj, Enum): - return obj.value - - # Handle Pydantic BaseModel - if isinstance(obj, BaseModel): - return normalize_for_json(obj.model_dump()) - - # Handle dict - if isinstance(obj, dict): - return {key: normalize_for_json(value) for key, value in obj.items()} - - # Handle tuple - if isinstance(obj, tuple): - return tuple(normalize_for_json(item) for item in obj) - - # Handle list - if isinstance(obj, list): - return [normalize_for_json(item) for item in obj] - - # For any other type, return as-is - return obj diff --git a/tests/unit/core/utils/test_json.py b/tests/unit/core/utils/test_json.py deleted file mode 100644 index 34f2e241..00000000 --- a/tests/unit/core/utils/test_json.py +++ /dev/null @@ -1,262 +0,0 @@ -"""Tests for JSON normalization utilities.""" - -from enum import Enum - -from pydantic import BaseModel - -from tetra_rp.core.utils.json import normalize_for_json - - -class Color(Enum): - """Test enum for color values.""" - - RED = "red" - GREEN = "green" - BLUE = "blue" - - -class Status(Enum): - """Test enum for status values.""" - - ACTIVE = 1 - INACTIVE = 0 - - -class SimpleModel(BaseModel): - """Simple Pydantic model for testing.""" - - name: str - value: int - - -class NestedModel(BaseModel): - """Nested Pydantic model for testing.""" - - simple: SimpleModel - color: Color - - -class ModelWithList(BaseModel): - """Model containing a list.""" - - items: list[int] - tags: list[str] - - -class ModelWithEnum(BaseModel): - """Model with enum field.""" - - status: Status - color: Color - - -class TestNormalizeForJsonPrimitives: - """Test normalize_for_json with primitive types.""" - - def test_integer_passthrough(self) -> None: - """Test integer passes through unchanged.""" - assert normalize_for_json(42) == 42 - assert normalize_for_json(0) == 0 - assert 
normalize_for_json(-100) == -100 - - def test_string_passthrough(self) -> None: - """Test string passes through unchanged.""" - assert normalize_for_json("hello") == "hello" - assert normalize_for_json("") == "" - - def test_float_passthrough(self) -> None: - """Test float passes through unchanged.""" - assert normalize_for_json(3.14) == 3.14 - assert normalize_for_json(0.0) == 0.0 - assert normalize_for_json(-1.5) == -1.5 - - def test_boolean_passthrough(self) -> None: - """Test boolean values pass through unchanged.""" - assert normalize_for_json(True) is True - assert normalize_for_json(False) is False - - def test_none_passthrough(self) -> None: - """Test None passes through unchanged.""" - assert normalize_for_json(None) is None - - -class TestNormalizeForJsonEnums: - """Test normalize_for_json with enum types.""" - - def test_enum_returns_value(self) -> None: - """Test enum is converted to its value.""" - assert normalize_for_json(Color.RED) == "red" - assert normalize_for_json(Color.GREEN) == "green" - assert normalize_for_json(Color.BLUE) == "blue" - - def test_integer_enum_returns_value(self) -> None: - """Test integer enum is converted to its value.""" - assert normalize_for_json(Status.ACTIVE) == 1 - assert normalize_for_json(Status.INACTIVE) == 0 - - def test_enum_in_dict(self) -> None: - """Test enum values are normalized in dicts.""" - result = normalize_for_json({"color": Color.RED, "status": Status.ACTIVE}) - assert result == {"color": "red", "status": 1} - - def test_enum_in_list(self) -> None: - """Test enum values are normalized in lists.""" - result = normalize_for_json([Color.RED, Color.GREEN, Status.ACTIVE]) - assert result == ["red", "green", 1] - - -class TestNormalizeForJsonBaseModel: - """Test normalize_for_json with Pydantic BaseModel.""" - - def test_simple_model_conversion(self) -> None: - """Test simple model is converted to dict.""" - model = SimpleModel(name="test", value=42) - result = normalize_for_json(model) - assert result == 
{"name": "test", "value": 42} - assert isinstance(result, dict) - - def test_model_with_enum_field(self) -> None: - """Test model with enum field normalizes enum values.""" - model = ModelWithEnum(status=Status.ACTIVE, color=Color.RED) - result = normalize_for_json(model) - assert result == {"status": 1, "color": "red"} - - def test_nested_model_conversion(self) -> None: - """Test nested models are recursively converted.""" - simple = SimpleModel(name="nested", value=10) - nested = NestedModel(simple=simple, color=Color.GREEN) - result = normalize_for_json(nested) - assert result == { - "simple": {"name": "nested", "value": 10}, - "color": "green", - } - - def test_model_with_list_field(self) -> None: - """Test model with list field is normalized.""" - model = ModelWithList(items=[1, 2, 3], tags=["a", "b"]) - result = normalize_for_json(model) - assert result == {"items": [1, 2, 3], "tags": ["a", "b"]} - - -class TestNormalizeForJsonCollections: - """Test normalize_for_json with collections.""" - - def test_empty_dict(self) -> None: - """Test empty dictionary passes through.""" - assert normalize_for_json({}) == {} - - def test_empty_list(self) -> None: - """Test empty list passes through.""" - assert normalize_for_json([]) == [] - - def test_empty_tuple(self) -> None: - """Test empty tuple passes through.""" - assert normalize_for_json(()) == () - - def test_dict_with_primitive_values(self) -> None: - """Test dict with primitive values.""" - data = {"a": 1, "b": "hello", "c": 3.14, "d": None} - assert normalize_for_json(data) == data - - def test_dict_with_mixed_values(self) -> None: - """Test dict with mixed nested types.""" - data = {"int": 42, "enum": Color.RED, "list": [1, 2, 3]} - result = normalize_for_json(data) - assert result == {"int": 42, "enum": "red", "list": [1, 2, 3]} - - def test_list_with_mixed_types(self) -> None: - """Test list with mixed types.""" - data = [1, "hello", 3.14, Color.RED, None] - result = normalize_for_json(data) - assert result 
== [1, "hello", 3.14, "red", None] - - def test_tuple_preserves_type(self) -> None: - """Test tuple values are normalized but type is preserved.""" - data = (1, "hello", Color.RED) - result = normalize_for_json(data) - assert result == (1, "hello", "red") - assert isinstance(result, tuple) - - def test_list_of_models(self) -> None: - """Test list of models is normalized.""" - models = [ - SimpleModel(name="m1", value=1), - SimpleModel(name="m2", value=2), - ] - result = normalize_for_json(models) - assert result == [{"name": "m1", "value": 1}, {"name": "m2", "value": 2}] - - -class TestNormalizeForJsonNested: - """Test normalize_for_json with deeply nested structures.""" - - def test_deeply_nested_dict(self) -> None: - """Test deeply nested dictionary structure.""" - data = {"a": {"b": {"c": {"d": {"e": Color.RED}}}}} - result = normalize_for_json(data) - assert result == {"a": {"b": {"c": {"d": {"e": "red"}}}}} - - def test_deeply_nested_list(self) -> None: - """Test deeply nested list structure.""" - data = [[[["value", Color.GREEN]]]] - result = normalize_for_json(data) - assert result == [[[["value", "green"]]]] - - def test_mixed_nested_collections(self) -> None: - """Test mixed nested dicts and lists.""" - data = { - "items": [ - {"id": 1, "color": Color.RED}, - {"id": 2, "color": Color.BLUE}, - ], - "status": Status.ACTIVE, - } - result = normalize_for_json(data) - assert result == { - "items": [ - {"id": 1, "color": "red"}, - {"id": 2, "color": "blue"}, - ], - "status": 1, - } - - def test_model_with_nested_list_of_models(self) -> None: - """Test model containing list of models.""" - - class Container(BaseModel): - items: list[SimpleModel] - - container = Container( - items=[ - SimpleModel(name="a", value=1), - SimpleModel(name="b", value=2), - ] - ) - result = normalize_for_json(container) - assert result == { - "items": [ - {"name": "a", "value": 1}, - {"name": "b", "value": 2}, - ] - } - - def test_complex_nested_structure(self) -> None: - """Test 
complex structure with models, enums, and collections.""" - model = ModelWithEnum(status=Status.ACTIVE, color=Color.RED) - data = { - "model": model, - "colors": [Color.RED, Color.GREEN, Color.BLUE], - "nested": { - "items": [1, 2, 3], - "status": Status.INACTIVE, - }, - } - result = normalize_for_json(data) - assert result == { - "model": {"status": 1, "color": "red"}, - "colors": ["red", "green", "blue"], - "nested": { - "items": [1, 2, 3], - "status": 0, - }, - } From 1a272a699ada3f5c4d85b59cc93e2c2d2c6e5264 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dean=20Qui=C3=B1anola?= Date: Thu, 22 Jan 2026 12:31:47 -0800 Subject: [PATCH 07/11] chore: update uv.lock --- uv.lock | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/uv.lock b/uv.lock index 658da713..d45f543a 100644 --- a/uv.lock +++ b/uv.lock @@ -2825,7 +2825,7 @@ wheels = [ [[package]] name = "tetra-rp" -version = "0.20.0" +version = "0.20.1" source = { editable = "." } dependencies = [ { name = "cloudpickle" }, From f3931e98e1d88cefe58e34c262a7c3814549f222 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dean=20Qui=C3=B1anola?= Date: Thu, 22 Jan 2026 13:23:52 -0800 Subject: [PATCH 08/11] fix: address PR 149 feedback comments - Fix worker_id access pattern in conftest.py with safer dict.get() approach - Use default_factory for mutable tuple in RetryConfig dataclass - Replace deprecated datetime.utcnow() with datetime.now(timezone.utc) --- src/tetra_rp/runtime/circuit_breaker.py | 20 +++++++++++--------- src/tetra_rp/runtime/reliability_config.py | 4 +++- tests/conftest.py | 7 +------ 3 files changed, 15 insertions(+), 16 deletions(-) diff --git a/src/tetra_rp/runtime/circuit_breaker.py b/src/tetra_rp/runtime/circuit_breaker.py index 8065fa43..c24d4f69 100644 --- a/src/tetra_rp/runtime/circuit_breaker.py +++ b/src/tetra_rp/runtime/circuit_breaker.py @@ -4,7 +4,7 @@ import logging import time from dataclasses import dataclass, field -from datetime import datetime +from datetime import datetime, timezone 
from enum import Enum from typing import Any, Callable, Optional @@ -28,7 +28,9 @@ class CircuitBreakerStats: success_count: int = 0 last_failure_at: Optional[datetime] = None last_success_at: Optional[datetime] = None - state_changed_at: datetime = field(default_factory=datetime.utcnow) + state_changed_at: datetime = field( + default_factory=lambda: datetime.now(timezone.utc) + ) total_requests: int = 0 total_failures: int = 0 @@ -104,7 +106,7 @@ async def _on_success(self) -> None: async with self._lock: self.stats.success_count += 1 self.stats.total_requests += 1 - self.stats.last_success_at = datetime.utcnow() + self.stats.last_success_at = datetime.now(timezone.utc) logger.debug( f"Circuit breaker {self.endpoint_url}: " @@ -126,7 +128,7 @@ async def _on_failure(self, error: Exception) -> None: self.stats.failure_count += 1 self.stats.total_failures += 1 self.stats.total_requests += 1 - self.stats.last_failure_at = datetime.utcnow() + self.stats.last_failure_at = datetime.now(timezone.utc) # Track failure times for sliding window now = time.time() @@ -155,7 +157,7 @@ def _transition_to_open(self) -> None: if self.stats.state == CircuitState.OPEN: return # Already open self.stats.state = CircuitState.OPEN - self.stats.state_changed_at = datetime.utcnow() + self.stats.state_changed_at = datetime.now(timezone.utc) self.stats.success_count = 0 logger.warning( f"Circuit breaker OPEN for {self.endpoint_url} " @@ -165,7 +167,7 @@ def _transition_to_open(self) -> None: def _transition_to_half_open(self) -> None: """Transition circuit to HALF_OPEN state.""" self.stats.state = CircuitState.HALF_OPEN - self.stats.state_changed_at = datetime.utcnow() + self.stats.state_changed_at = datetime.now(timezone.utc) self.stats.failure_count = 0 self.stats.success_count = 0 logger.info( @@ -175,7 +177,7 @@ def _transition_to_half_open(self) -> None: def _transition_to_closed(self) -> None: """Transition circuit to CLOSED state.""" self.stats.state = CircuitState.CLOSED - 
self.stats.state_changed_at = datetime.utcnow() + self.stats.state_changed_at = datetime.now(timezone.utc) self.stats.failure_count = 0 self.stats.success_count = 0 self._failure_times.clear() @@ -185,14 +187,14 @@ def _should_attempt_recovery(self) -> bool: """Check if enough time has passed to attempt recovery.""" if not self.stats.last_failure_at: return False - elapsed = datetime.utcnow() - self.stats.state_changed_at + elapsed = datetime.now(timezone.utc) - self.stats.state_changed_at return elapsed.total_seconds() >= self.timeout_seconds def _seconds_until_recovery(self) -> int: """Get seconds until recovery can be attempted.""" if not self.stats.state_changed_at: return self.timeout_seconds - elapsed = datetime.utcnow() - self.stats.state_changed_at + elapsed = datetime.now(timezone.utc) - self.stats.state_changed_at remaining = self.timeout_seconds - int(elapsed.total_seconds()) return max(0, remaining) diff --git a/src/tetra_rp/runtime/reliability_config.py b/src/tetra_rp/runtime/reliability_config.py index 2bd063e6..40748249 100644 --- a/src/tetra_rp/runtime/reliability_config.py +++ b/src/tetra_rp/runtime/reliability_config.py @@ -43,7 +43,9 @@ class RetryConfig: base_delay: float = 0.5 max_delay: float = 10.0 jitter: float = 0.2 - retryable_exceptions: tuple = field(default=(asyncio.TimeoutError, ConnectionError)) + retryable_exceptions: tuple = field( + default_factory=lambda: (asyncio.TimeoutError, ConnectionError) + ) retryable_status_codes: set = field( default_factory=lambda: {408, 429, 500, 502, 503, 504} ) diff --git a/tests/conftest.py b/tests/conftest.py index 8ced0595..7266b59d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -225,12 +225,7 @@ def worker_temp_dir( Returns: Path to worker-specific temporary directory. 
""" - # Try to get worker_id from xdist, default to 'master' if not available - worker_id = getattr(request.config, "workerinput", None) - if worker_id is None: - worker_id = "master" - else: - worker_id = worker_id["workerid"] + worker_id = getattr(request.config, "workerinput", {}).get("workerid", "master") if worker_id == "master": # Single worker (non-parallel) From 10ddab91c33509c69759e3b8f5ba103606eb4940 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dean=20Qui=C3=B1anola?= Date: Thu, 22 Jan 2026 14:12:32 -0800 Subject: [PATCH 09/11] fix: address new PR 149 feedback comments - Update docstring in conftest.py to clarify worker_id is extracted from request - Remove unused asyncio import from reliability_config.py - Use built-in TimeoutError instead of asyncio.TimeoutError (equivalent in Python 3.8+) --- src/tetra_rp/runtime/reliability_config.py | 3 +-- tests/conftest.py | 5 +++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/tetra_rp/runtime/reliability_config.py b/src/tetra_rp/runtime/reliability_config.py index 40748249..46bc6429 100644 --- a/src/tetra_rp/runtime/reliability_config.py +++ b/src/tetra_rp/runtime/reliability_config.py @@ -1,6 +1,5 @@ """Centralized configuration for reliability features.""" -import asyncio import os from dataclasses import dataclass, field from enum import Enum @@ -44,7 +43,7 @@ class RetryConfig: max_delay: float = 10.0 jitter: float = 0.2 retryable_exceptions: tuple = field( - default_factory=lambda: (asyncio.TimeoutError, ConnectionError) + default_factory=lambda: (TimeoutError, ConnectionError) ) retryable_status_codes: set = field( default_factory=lambda: {408, 429, 500, 502, 503, 504} diff --git a/tests/conftest.py b/tests/conftest.py index 7266b59d..13894819 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -216,11 +216,12 @@ def worker_temp_dir( """Provide worker-specific temporary directory for file system isolation. 
Each xdist worker gets its own isolated temp directory to prevent - file system conflicts when tests write to shared paths. + file system conflicts when tests write to shared paths. The worker ID + is extracted from the pytest-xdist request object. Args: tmp_path_factory: Pytest's temporary path factory. - request: Pytest request object. + request: Pytest request object used to extract worker ID. Returns: Path to worker-specific temporary directory. From 27f4841c286f7dad690d4626672d253b0dd4421a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dean=20Qui=C3=B1anola?= Date: Tue, 27 Jan 2026 11:18:26 -0800 Subject: [PATCH 10/11] docs: remove references to `FLASH_MOTHERSHIP_ID` --- docs/Deployment_Architecture.md | 1 - docs/Flash_Deploy_Guide.md | 13 +++---------- uv.lock | 2 +- 3 files changed, 4 insertions(+), 12 deletions(-) diff --git a/docs/Deployment_Architecture.md b/docs/Deployment_Architecture.md index dd488efe..b1ffede0 100644 --- a/docs/Deployment_Architecture.md +++ b/docs/Deployment_Architecture.md @@ -301,6 +301,5 @@ Each reconciliation action updates State Manager: - `FLASH_MANIFEST_PATH` - Optional explicit path to manifest ### Child Endpoints -- `FLASH_MOTHERSHIP_ID` - Mothership's endpoint ID (constructs URL as `https://{FLASH_MOTHERSHIP_ID}.api.runpod.ai`) - `FLASH_RESOURCE_NAME` - Which resource config this endpoint represents - `RUNPOD_ENDPOINT_ID` - This child's endpoint ID diff --git a/docs/Flash_Deploy_Guide.md b/docs/Flash_Deploy_Guide.md index b09dab14..e6155868 100644 --- a/docs/Flash_Deploy_Guide.md +++ b/docs/Flash_Deploy_Guide.md @@ -397,7 +397,6 @@ sequenceDiagram **Key Environment Variables**: - `FLASH_RESOURCE_NAME`: This endpoint's resource config name (e.g., "gpu_config") -- `FLASH_MOTHERSHIP_ID`: Mothership endpoint ID for manifest queries - `FLASH_MANIFEST_PATH`: Optional override for manifest location - `RUNPOD_ENDPOINT_ID`: This endpoint's RunPod endpoint ID @@ -1025,11 +1024,6 @@ graph LR - Identifies which resource config this endpoint 
represents - Used by ServiceRegistry for local vs remote detection -**FLASH_MOTHERSHIP_ID** (Required for remote calls) -- Mothership endpoint ID -- Used to construct mothership URL for manifest queries -- Format: `{endpoint_id}` (constructed as `https://{FLASH_MOTHERSHIP_ID}.api.runpod.ai`) - **FLASH_MANIFEST_PATH** (Optional) - Override default manifest file location - If not set, searches: cwd, module dir, parent dirs @@ -1283,10 +1277,9 @@ logging.getLogger("tetra_rp.runtime.service_registry").setLevel(logging.DEBUG) **Cause**: ServiceRegistry unable to query mothership or manifest outdated **Solution**: -1. Verify `FLASH_MOTHERSHIP_ID` environment variable is set -2. Check mothership endpoint is running: `curl https://{mothership_id}.api.runpod.ai/manifest` -3. Verify manifest includes the resource config: `grep resource_name flash_manifest.json` -4. Check network connectivity between child and mothership endpoints +1. Check mothership endpoint is running: `curl https://{mothership_id}.api.runpod.ai/ping` +2. Verify manifest includes the resource config: `grep resource_name flash_manifest.json` +3. Check network connectivity between child and mothership endpoints ### Issue: Manifest cache staleness diff --git a/uv.lock b/uv.lock index d45f543a..666318ed 100644 --- a/uv.lock +++ b/uv.lock @@ -2825,7 +2825,7 @@ wheels = [ [[package]] name = "tetra-rp" -version = "0.20.1" +version = "0.21.0" source = { editable = "." } dependencies = [ { name = "cloudpickle" }, From e9f582c32dd59e8e1833b4aaa032102d355d3dba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dean=20Qui=C3=B1anola?= Date: Tue, 27 Jan 2026 13:45:17 -0800 Subject: [PATCH 11/11] refactor: remove obsolete FLASH_MOTHERSHIP_ID environment variable The FLASH_MOTHERSHIP_ID variable was originally intended for child endpoints to identify their parent mothership. However, the current peer-to-peer architecture uses RUNPOD_ENDPOINT_ID directly for State Manager queries, making this variable redundant. 
This variable was set but never consumed anywhere in the codebase, making it safe to remove without affecting functionality. Changes: - Remove FLASH_MOTHERSHIP_ID from environment dict in mothership_provisioner.py - Remove test assertion validating FLASH_MOTHERSHIP_ID presence - Update documentation references to remove mentions of FLASH_MOTHERSHIP_ID - All tests continue to pass --- docs/Cross_Endpoint_Routing.md | 3 +-- src/tetra_rp/runtime/mothership_provisioner.py | 1 - tests/unit/runtime/test_mothership_provisioner.py | 1 - 3 files changed, 1 insertion(+), 4 deletions(-) diff --git a/docs/Cross_Endpoint_Routing.md b/docs/Cross_Endpoint_Routing.md index aa83c194..82854d2a 100644 --- a/docs/Cross_Endpoint_Routing.md +++ b/docs/Cross_Endpoint_Routing.md @@ -463,7 +463,7 @@ class ServiceRegistry: manifest_path: Path to flash_manifest.json. Defaults to FLASH_MANIFEST_PATH env var or auto-detection. manifest_client: Manifest service client for mothership API. If None, - creates one from FLASH_MOTHERSHIP_ID env var. + creates one using State Manager API. cache_ttl: Manifest cache lifetime in seconds (default: 300). 
Environment Variables (for local vs remote detection): @@ -973,7 +973,6 @@ except Exception as e: ```python # Check environment variables import os -print(f"FLASH_MOTHERSHIP_ID: {os.getenv('FLASH_MOTHERSHIP_ID')}") print(f"RUNPOD_ENDPOINT_ID: {os.getenv('RUNPOD_ENDPOINT_ID')}") # Check state manager client directly diff --git a/src/tetra_rp/runtime/mothership_provisioner.py b/src/tetra_rp/runtime/mothership_provisioner.py index 8035bf25..1ae31849 100644 --- a/src/tetra_rp/runtime/mothership_provisioner.py +++ b/src/tetra_rp/runtime/mothership_provisioner.py @@ -270,7 +270,6 @@ def create_resource_from_manifest( env = { "FLASH_RESOURCE_NAME": resource_name, - "FLASH_MOTHERSHIP_ID": os.getenv("RUNPOD_ENDPOINT_ID"), } # Add "tmp-" prefix for test-mothership deployments diff --git a/tests/unit/runtime/test_mothership_provisioner.py b/tests/unit/runtime/test_mothership_provisioner.py index a4777eec..2a4e36d0 100644 --- a/tests/unit/runtime/test_mothership_provisioner.py +++ b/tests/unit/runtime/test_mothership_provisioner.py @@ -399,7 +399,6 @@ def test_create_resource_from_manifest_serverless(self): assert isinstance(resource, ServerlessResource) # ServerlessResource may append "-fb" suffix during initialization assert resource_name in resource.name - assert resource.env["FLASH_MOTHERSHIP_ID"] == "mothership-123" assert resource.env["FLASH_RESOURCE_NAME"] == resource_name def test_create_resource_from_manifest_live_serverless(self):