diff --git a/README.md b/README.md index 2c1f1eda..d442dbf4 100644 --- a/README.md +++ b/README.md @@ -340,6 +340,39 @@ results = await asyncio.gather( ) ``` +### Load-Balanced Endpoints with HTTP Routing + +For API endpoints requiring low-latency HTTP access with direct routing, use load-balanced endpoints: + +```python +from tetra_rp import LiveLoadBalancer, remote + +api = LiveLoadBalancer(name="api-service") + +@remote(api, method="POST", path="/api/process") +async def process_data(x: int, y: int): + return {"result": x + y} + +@remote(api, method="GET", path="/api/health") +def health_check(): + return {"status": "ok"} + +# Call functions directly +result = await process_data(5, 3) # → {"result": 8} +``` + +**Key differences from queue-based endpoints:** +- **Direct HTTP routing** - Requests routed directly to workers, no queue +- **Lower latency** - No queuing overhead +- **Custom HTTP methods** - GET, POST, PUT, DELETE, PATCH support +- **No automatic retries** - Users handle errors directly + +Load-balanced endpoints are ideal for REST APIs, webhooks, and real-time services. Queue-based endpoints are better for batch processing and fault-tolerant workflows. + +For detailed information: +- **User guide:** [Using @remote with Load-Balanced Endpoints](docs/Using_Remote_With_LoadBalancer.md) +- **Runtime architecture:** [LoadBalancer Runtime Architecture](docs/LoadBalancer_Runtime_Architecture.md) - details on deployment, request flows, and execution + ## How it works Flash orchestrates workflow execution through a sophisticated multi-step process: @@ -460,6 +493,8 @@ After `flash build` completes: For more details on the handler architecture, see [docs/Runtime_Generic_Handler.md](docs/Runtime_Generic_Handler.md). +For information on load-balanced endpoints (required for Mothership and HTTP services), see [docs/Load_Balancer_Endpoints.md](docs/Load_Balancer_Endpoints.md). + #### Troubleshooting Build Issues **No @remote functions found:** diff --git a/docs/LoadBalancer_Runtime_Architecture.md b/docs/LoadBalancer_Runtime_Architecture.md new file mode 100644 index 00000000..b55d80f2 --- /dev/null +++ b/docs/LoadBalancer_Runtime_Architecture.md @@ -0,0 +1,708 @@ +# Load-Balanced Endpoint Runtime Architecture + +## Overview + +This document explains what happens after a load-balanced endpoint is deployed on RunPod and is actively running. It covers the deployment architecture, request flows, and execution patterns for both direct HTTP requests and @remote function calls. + +## Deployment Architecture + +### Container Image and Startup + +When you deploy a `LoadBalancerSlsResource` endpoint with `flash build` and `flash deploy`: + +```mermaid +graph TD + A["User Code"] -->|flash build| B["Generate handler_service.py"] + B -->|FastAPI App| C["handler_service.py"] + C -->|flash deploy| D["Push to RunPod"] + D -->|Create Container| E["RunPod Container
tetra-rp-lb image"] + E --> F["FastAPI Server
uvicorn on port 8000"] + F --> G["Load your handler"] + G --> H["Endpoint Ready"] + + style A fill:#1976d2,stroke:#0d47a1,stroke-width:3px,color:#fff + style B fill:#ff6b35,stroke:#c41e0f,stroke-width:3px,color:#fff + style C fill:#1976d2,stroke:#0d47a1,stroke-width:3px,color:#fff + style D fill:#ff6b35,stroke:#c41e0f,stroke-width:3px,color:#fff + style E fill:#1976d2,stroke:#0d47a1,stroke-width:3px,color:#fff + style F fill:#0d7f1f,stroke:#0d4f1f,stroke-width:3px,color:#fff + style G fill:#0d7f1f,stroke:#0d4f1f,stroke-width:3px,color:#fff + style H fill:#2e7d32,stroke:#1b5e20,stroke-width:3px,color:#fff +``` + +**Important:** `endpoint_url` is auto-generated by RunPod after deployment +- Cannot be specified by users +- Generated as: `https:///` +- Automatically populated in the resource after `deploy()` completes +- Available via `resource.endpoint_url` property (read-only) + +### What Gets Deployed + +The generated handler file contains: + +```python +# handler_service.py (auto-generated) +from fastapi import FastAPI +from tetra_rp.runtime.lb_handler import create_lb_handler + +# User functions imported +from api.endpoints import process_data +from api.health import health_check + +# Route registry +ROUTE_REGISTRY = { + ("POST", "/api/process"): process_data, + ("GET", "/api/health"): health_check, +} + +# FastAPI app created +app = create_lb_handler(ROUTE_REGISTRY) + +if __name__ == "__main__": + import uvicorn + uvicorn.run(app, host="0.0.0.0", port=8000) +``` + +**Container Setup:** +- Base image: `runpod/tetra-rp-lb:latest` (contains FastAPI, uvicorn, dependencies) +- Entrypoint: Runs `python handler_service.py` +- Port: 8000 (internal) +- RunPod exposes this via HTTPS endpoint URL +- Health check: Polls `/ping` endpoint every 30 seconds with 15 second timeout per check +- All HTTP requests to the endpoint include authentication via `RUNPOD_API_KEY` environment variable (if set) + +### Deployment Lifecycle + +```mermaid +graph TD + A["LoadBalancerSlsResource created"] -->|flash build| B["Generate handler file"] + B -->|flash deploy| C["Push to RunPod"] + C --> D["RunPod creates container"] + D --> E["Container starts uvicorn"] + E --> F["FastAPI app loads"] + F --> G["Import user functions"] + G --> H["Register routes"] + H --> I["Endpoint ready"] + I --> J["Health checks pass"] + J --> K["Endpoint active"] +``` + +## Request Flow + +### Direct HTTP Request (User Routes) + +When a client makes an HTTP request to your deployed endpoint: + +```mermaid +sequenceDiagram + participant Client + participant RunPod as RunPod Router + participant Container as Endpoint Container + participant FastAPI + participant UserFunc as User Function + + Client->>RunPod: HTTPS POST /api/process + RunPod->>Container: Forward to port 8000 + Container->>FastAPI: HTTP POST /api/process + FastAPI->>FastAPI: Match (POST, /api/process)
in ROUTE_REGISTRY + FastAPI->>UserFunc: Call process_data(x=5, y=3) + UserFunc->>UserFunc: Execute function code + UserFunc-->>FastAPI: Return {"result": 8} + FastAPI->>FastAPI: Serialize to JSON + FastAPI-->>Container: HTTP 200 response + Container-->>RunPod: Response body + RunPod-->>Client: HTTPS response +``` + +**Example Flow:** + +```python +# User code +@remote(api, method="POST", path="/api/process") +async def process_data(x: int, y: int): + return {"result": x + y} + +# Client request +POST https://my-endpoint.runpod.ai/api/process +Content-Type: application/json +{"x": 5, "y": 3} + +# On RunPod: +# 1. Request arrives at container port 8000 +# 2. FastAPI receives POST /api/process +# 3. FastAPI parses JSON body: {"x": 5, "y": 3} +# 4. FastAPI calls process_data(x=5, y=3) +# 5. Function executes: returns {"result": 8} +# 6. FastAPI serializes response +# 7. Returns HTTP 200 with body {"result": 8} +# 8. RunPod wraps in HTTPS response +# 9. Client receives response +``` + +### @remote Function Call (Framework Endpoint) + +When you call an `@remote` decorated function from your local code: + +```mermaid +sequenceDiagram + participant Local as Local Code + participant Stub as LoadBalancerSlsStub + participant RunPod as RunPod Router + participant Container as Endpoint Container + participant Execute as /execute Handler + + Local->>Stub: await process_data(5, 3) + Stub->>Stub: Extract function source code
via AST inspection + Stub->>Stub: Serialize args with cloudpickle
+ base64 encode + Stub->>RunPod: POST /execute + RunPod->>Container: Forward to port 8000 + Container->>Execute: HTTP POST /execute + Execute->>Execute: Parse JSON body + Execute->>Execute: Deserialize arguments
(base64 decode + cloudpickle loads) + Execute->>Execute: Extract function code string + Execute->>Execute: exec(code) in isolated namespace + Execute->>Execute: Call func(5, 3) + Execute->>Execute: Get result: {"result": 8} + Execute->>Execute: Serialize result with cloudpickle
+ base64 encode + Execute-->>Container: HTTP 200 {success: true, result: base64} + Container-->>RunPod: Response body + RunPod-->>Stub: Response body + Stub->>Stub: Deserialize result
(base64 decode + cloudpickle loads) + Stub-->>Local: Return {"result": 8} +``` + +**Example Flow:** + +```python +# Local code - after deployment +api = LoadBalancerSlsResource(name="user-service", + imageName="runpod/tetra-rp-lb:latest") + +# Deploy the endpoint (generates endpoint_url automatically) +await api.deploy() +# After deploy, api.endpoint_url is populated by RunPod +# Example: "https://xxx-yyy-zzz.runpod.io" + +@remote(api, method="POST", path="/api/process") +async def process_data(x: int, y: int): + return {"result": x + y} + +# Call the function locally +result = await process_data(5, 3) + +# What happens: +# 1. Decorator finds LoadBalancerSlsStub in registry +# 2. Stub extracts function source code via AST +# 3. Stub serializes arguments: cloudpickle.dumps([5, 3]) +# 4. Stub POST to https://my-endpoint.runpod.ai/execute +# 5. Container receives request at /execute endpoint +# 6. create_lb_handler's execute_remote_function handles it: +# a. Parses JSON body +# b. Deserializes arguments: [5, 3] +# c. Executes: exec(function_code) in isolated namespace +# d. Calls func(5, 3) +# e. Gets result: {"result": 8} +# f. Serializes result via cloudpickle +# g. Returns {success: true, result: base64_string} +# 7. Stub deserializes result +# 8. Returns {"result": 8} to caller +``` + +## Deployment Execution Model + +### Local Development (LiveLoadBalancer) + +When using `LiveLoadBalancer` for local testing, endpoints expose two types of routes: + +1. **User-Defined Routes** (e.g., `/api/health`, `/api/users`) + - Called via direct HTTP requests + - Called via `@remote` decorator (uses /execute internally) + +2. **Framework Endpoints** + - `/execute` - Accepts serialized function code for @remote execution + - `/ping` - Health check endpoint + +### Deployed Endpoints (LoadBalancerSlsResource) + +When deployed to production, endpoints **only expose user-defined routes** for security: + +1. **User-Defined Routes** (e.g., `/api/health`, `/api/users`) + - Called via direct HTTP requests from clients + - Called via `@remote` decorator (stub translates to HTTP requests to user routes) + - `/execute` endpoint NOT exposed (removed for security) + +2. 
**Framework Endpoints** + - `/ping` - Health check endpoint only + +### Request Handling by Execution Type + +#### Direct HTTP Requests (Always Works) + +``` +GET /health +POST /api/users +PUT /api/users/{user_id} +DELETE /api/users/{user_id} +``` + +**Characteristics:** +- Called by external HTTP clients +- FastAPI handles routing automatically +- Standard HTTP request/response +- No serialization/deserialization +- Direct function execution +- Errors return HTTP error codes + +**Example:** +```python +@remote(api, method="GET", path="/health") +def health_check(): + return {"status": "ok"} + +# Client can call: +GET https://my-endpoint.runpod.ai/health +# Response: 200 OK {"status": "ok"} +``` + +#### @remote Function Calls (Different Local vs Deployed) + +**Local (LiveLoadBalancer):** +```python +@remote(api, method="POST", path="/api/process") +async def process_data(x: int, y: int): + return {"result": x + y} + +# Called via @remote: +result = await process_data(5, 3) # Uses /execute internally (local only) +``` + +**Deployed (LoadBalancerSlsResource):** +```python +@remote(api, method="POST", path="/api/process") +async def process_data(x: int, y: int): + return {"result": x + y} + +# Called via @remote: +result = await process_data(5, 3) +# Stub automatically translates to: POST /api/process {"x": 5, "y": 3} +# No /execute endpoint involved (security) +``` + +**Key Differences:** +- Local: Serializes function code, POSTs to /execute +- Deployed: Maps arguments to JSON, POSTs to user-defined route +- No code changes needed - stub handles both automatically + +**Important Implementation Detail: Stub Decision Logic** + +The stub determines which execution path to use by checking: +1. Is this a `LiveLoadBalancer`? → Always use `/execute` for local development +2. Does the function have `method` and `path` metadata from `@remote` decorator? → If yes, use user-defined route +3. If routing metadata is incomplete or missing → Falls back to `/execute` (will fail on deployed endpoints) + +This means if you decorate a function for `LoadBalancerSlsResource` without specifying both `method` and `path`, the stub will attempt to use `/execute`, which doesn't exist in production. Always provide complete routing metadata for deployed endpoints. + +**Important Implementation Detail: Parameter Mapping** + +When using user-defined routes (deployed endpoints), the stub inspects the function signature and maps positional and keyword arguments to the HTTP request JSON body: + +```python +@remote(api, method="POST", path="/api/process") +async def process_data(x: int, y: int): + return {"result": x + y} + +# Local call: +result = await process_data(5, 3) + +# Gets translated to: +POST /api/process +{ + "x": 5, + "y": 3 +} +``` + +The stub uses Python's `inspect.signature()` to map positional args to parameter names. This requires that: +- Function parameters are JSON-serializable types (int, str, bool, list, dict, None) +- Function signature is available (defined at module level, not dynamically created) +- No complex types (custom classes, Request objects, etc.) are used as parameters + +## Execution Flow Diagram + +```mermaid +graph TD + A["HTTP Request arrives at
RunPod Endpoint"] -->|HTTPS| B["RunPod Router
Domain stripping"] + B -->|Strips domain
Forwards to container| C["Container Port 8000
uvicorn/FastAPI"] + + C -->|Route decision| D{Is it /execute?} + + D -->|Yes: Framework| E["Framework Handler
execute_remote_function"] + D -->|No: User Route| F["FastAPI Router
Match method + path in
ROUTE_REGISTRY"] + + E --> E1["1. Parse JSON body"] + E1 --> E2["2. Deserialize args/kwargs
base64 + cloudpickle"] + E2 --> E3["3. exec function_code
in isolated namespace"] + E3 --> E4["4. Call func with args"] + E4 --> E5["5. Serialize result
cloudpickle + base64"] + E5 --> G["Build Response
success: true/false"] + + F --> F1["1. Find handler function
from ROUTE_REGISTRY"] + F1 --> F2["2. Parse request parameters"] + F2 --> F3["3. Call function
with parameters"] + F3 --> F4["4. Get result"] + F4 --> G + + G -->|Serialize response| H["FastAPI Response Obj
JSON or {success, result}"] + H -->|Wrap in HTTPS| I["RunPod Router
Wraps response"] + I -->|Send back| J["HTTP Response to Client"] + + style A fill:#1976d2,stroke:#0d47a1,stroke-width:3px,color:#fff + style B fill:#ff6b35,stroke:#c41e0f,stroke-width:3px,color:#fff + style C fill:#1976d2,stroke:#0d47a1,stroke-width:3px,color:#fff + style D fill:#ff6b35,stroke:#c41e0f,stroke-width:3px,color:#fff + style E fill:#0d7f1f,stroke:#0d4f1f,stroke-width:3px,color:#fff + style F fill:#0d7f1f,stroke:#0d4f1f,stroke-width:3px,color:#fff + style E1 fill:#2e7d32,stroke:#1b5e20,stroke-width:2px,color:#fff + style E2 fill:#2e7d32,stroke:#1b5e20,stroke-width:2px,color:#fff + style E3 fill:#2e7d32,stroke:#1b5e20,stroke-width:2px,color:#fff + style E4 fill:#2e7d32,stroke:#1b5e20,stroke-width:2px,color:#fff + style E5 fill:#2e7d32,stroke:#1b5e20,stroke-width:2px,color:#fff + style F1 fill:#2e7d32,stroke:#1b5e20,stroke-width:2px,color:#fff + style F2 fill:#2e7d32,stroke:#1b5e20,stroke-width:2px,color:#fff + style F3 fill:#2e7d32,stroke:#1b5e20,stroke-width:2px,color:#fff + style F4 fill:#2e7d32,stroke:#1b5e20,stroke-width:2px,color:#fff + style G fill:#1976d2,stroke:#0d47a1,stroke-width:3px,color:#fff + style H fill:#1976d2,stroke:#0d47a1,stroke-width:3px,color:#fff + style I fill:#ff6b35,stroke:#c41e0f,stroke-width:3px,color:#fff + style J fill:#0d7f1f,stroke:#0d4f1f,stroke-width:3px,color:#fff +``` + +## Security Model at Runtime + +### /execute Endpoint + +The `/execute` endpoint is an internal framework endpoint that: + +1. **Accepts arbitrary Python code** (serialized as string) +2. **Executes it** in an isolated namespace using Python's `exec()` +3. **Returns results** back to caller + +**Critical Security Model:** + +The `/execute` endpoint is **only exposed on `LiveLoadBalancer` for local development**. It is **explicitly removed from deployed `LoadBalancerSlsResource` endpoints** for security reasons. + +**Why This Design Is Necessary:** + +The `/execute` endpoint accepts and executes arbitrary Python code sent in HTTP requests. 
An unauthorized user with access to this endpoint could: +- Execute system commands (e.g., `os.system()`) +- Access file system data (e.g., read environment variables, credentials) +- Modify application state or data +- Use your infrastructure for malicious purposes + +**Why This Is Secure When Used Correctly:** + +- In `LiveLoadBalancer` (local development): Code originates from your own `@remote` decorator +- You control what function code is serialized and sent +- Only accessible during local testing, never exposed publicly +- Same trusted-client model as queue-based serverless endpoints + +**What Happens When Deployed:** + +``` +LiveLoadBalancer (local): +- /execute endpoint: INCLUDED (for @remote function execution) +- User routes: Included +- Safe because: Only you can run your code locally + +LoadBalancerSlsResource (deployed): +- /execute endpoint: REMOVED for security +- User routes: Included +- Safe because: No arbitrary code execution possible +``` + +**If /execute Was Exposed (Don't Do This):** + +```python +# Attacker's request +POST https://my-endpoint.runpod.ai/execute +{ + "function_name": "malicious", + "function_code": "import os; os.system('rm -rf /')", + "args": [], + "kwargs": {} +} + +# This would execute arbitrary system commands on your infrastructure +``` + +**Best Practices:** + +- Never manually add `/execute` to deployed endpoints +- Use the default `create_lb_handler()` behavior (removes `/execute`) +- Always use `LoadBalancerSlsResource` for production (not `LiveLoadBalancer`) +- Test locally with `LiveLoadBalancer` first +- For debugging deployed endpoints, use container logs, not code injection + +## Concurrency and Scaling + +### How RunPod Handles Concurrent Requests + +```mermaid +graph TD + A["Request 1
POST /api/process"] -->|→ Worker 1| B["Container [Worker 1]
Executes Request 1"] + C["Request 2
POST /api/users"] -->|→ Worker 1| D["Queued in Worker 1"] + D -->|Worker available| E["Container [Worker 1]
Executes Request 2
Concurrently"] + F["Request 3
POST /api/health"] -->|→ Worker 2| G["Container [Worker 2]
Executes Request 3"] + + H["RunPod Scaler
REQUEST_COUNT"] -->|Queue grows| I["Monitor Queue Depth"] + I -->|Q ≥ 3| J["Spin up Worker 3"] + I -->|Q ≥ 6| K["Spin up Worker 4"] + I -->|Q empty| L["Wind down Workers"] + + style A fill:#1976d2,stroke:#0d47a1,stroke-width:3px,color:#fff + style B fill:#0d7f1f,stroke:#0d4f1f,stroke-width:3px,color:#fff + style C fill:#1976d2,stroke:#0d47a1,stroke-width:3px,color:#fff + style D fill:#ff6b35,stroke:#c41e0f,stroke-width:3px,color:#fff + style E fill:#0d7f1f,stroke:#0d4f1f,stroke-width:3px,color:#fff + style F fill:#1976d2,stroke:#0d47a1,stroke-width:3px,color:#fff + style G fill:#0d7f1f,stroke:#0d4f1f,stroke-width:3px,color:#fff + style H fill:#ff6b35,stroke:#c41e0f,stroke-width:3px,color:#fff + style I fill:#ff6b35,stroke:#c41e0f,stroke-width:3px,color:#fff + style J fill:#2e7d32,stroke:#1b5e20,stroke-width:2px,color:#fff + style K fill:#2e7d32,stroke:#1b5e20,stroke-width:2px,color:#fff + style L fill:#2e7d32,stroke:#1b5e20,stroke-width:2px,color:#fff +``` + +### Function Execution + +- Each request executes in isolated context +- async functions execute with asyncio +- Multiple requests can process concurrently (with async) +- Synchronous functions block worker thread + +**Example Concurrency:** + +```python +@remote(api, method="POST", path="/api/process") +async def process_data(x: int): + import time + await asyncio.sleep(10) # Simulate work + return {"result": x} + +# If 5 requests come in simultaneously: +# - Request 1: await asyncio.sleep(10) → Worker 1 +# - Request 2: await asyncio.sleep(10) → Worker 1 (concurrent) +# - Request 3: await asyncio.sleep(10) → Worker 1 (concurrent) +# - Request 4: await asyncio.sleep(10) → Worker 2 (new worker) +# - Request 5: await asyncio.sleep(10) → Worker 2 (concurrent) +# +# All 5 complete in ~10s (concurrent within workers) +``` + +## Error Handling at Runtime + +### Client Errors + +``` +POST https://endpoint.runpod.ai/api/users +{"invalid": "json" + +# Response: 422 Unprocessable Entity +{ + "detail": [ + { + "type": "json_error", + "loc": ["body"], + "msg": "JSON decode error" + } + ] +} +``` + +### Function Errors + +``` +@remote(api, method="POST", path="/api/users") +async def create_user(name: str): + if not name: + raise ValueError("Name required") + return {"id": 1, "name": name} + +# Call with invalid data: +POST https://endpoint.runpod.ai/api/users +{"name": ""} + +# Response: 422 Validation Error or 500 Internal Error +# (depending on where error occurs) +``` + +### @remote Execution Errors + +```python +# Local code +@remote(api, method="POST", path="/api/process") +async def process_data(x: int): + raise RuntimeError("Processing failed") + +result = await process_data(5) +# Raises RuntimeError: "Remote execution failed: Processing failed" +``` + +## Performance Characteristics + +### Request Latency (approximate) + +``` +Direct HTTP Request: +- Request → RunPod Router: 10-50ms +- FastAPI routing: 1-5ms +- Function execution: Variable +- Serialization: Variable +- Response: 10-50ms +Total (no-op function): 30-110ms + +@remote Function Call: +- Function serialization: 1-10ms +- HTTP request to /execute: 10-50ms +- Deserialization: 1-10ms +- Function execution: Variable +- Result serialization: 1-10ms +- Result deserialization: 1-10ms +- Response: 10-50ms +Total (no-op function): 40-150ms +``` + +### Memory Usage + +- FastAPI app baseline: ~50-100MB +- Per function in namespace: ~0.5-5MB +- Serialized args/result: Variable (depends on data size) +- RunPod allocates: Depends on pod type + +### Request Size Limits + +- RunPod has limits 
on request body size +- Serialized data (via cloudpickle) increases size +- Large arguments may hit limits +- Consider streaming for large payloads + +## Monitoring and Debugging at Runtime + +### Logs Available on RunPod + +``` +Container logs (uvicorn/FastAPI): +- Request arrival +- Route matching +- Function execution +- Errors and exceptions +- Response generation + +Environment: +- Pod ID +- Worker ID +- GPU allocation +- Memory usage +``` + +### Health Checks + +``` +GET https://endpoint.runpod.ai/ping +Response: 200 OK {"status": "healthy"} + +RunPod polls /ping every 30 seconds +- 200 OK → Worker healthy +- Non-200 → Worker unhealthy +- No response → Worker down +- Unhealthy workers replaced +``` + +### Common Issues at Runtime + +**"Connection refused"** +- Container not running +- Uvicorn failed to start +- Check container logs + +**"Timeout after 30s"** +- Function took >30s +- Network issue +- Increase timeout if needed + +**"500 Internal Server Error"** +- Function raised exception +- Check container logs +- Verify function code + +## Deployment Considerations + +### Image Selection + +``` +tetra-rp-lb:latest (default) +- FastAPI + uvicorn pre-installed +- Tetra runtime dependencies +- Optimized for LB endpoints + +Custom image: +- Must have FastAPI, uvicorn +- Must expose port 8000 +- /ping endpoint should work +``` + +### Pod Configuration + +```python +LoadBalancerSlsResource( + name="my-api", + imageName="runpod/tetra-rp-lb:latest", + gpus=[GpuGroup.AMPERE_80], # Optional: if compute needed + instanceIds=[...], # Or specify CPU instances + workersMax=5, # Max concurrent workers + template=PodTemplate(...) # Storage, env vars, etc. +) +``` + +### Network + +``` +Incoming: +- HTTPS endpoint provided by RunPod +- Auto-scaled based on REQUEST_COUNT +- Health checks ensure availability + +Outgoing: +- Your functions can make HTTP requests +- Can access external APIs +- Can access other RunPod endpoints +``` + +## Summary + +**What Happens at Runtime:** + +1. **Deployment** - FastAPI app runs in RunPod container +2. **Request Arrival** - HTTP request reaches container +3. **Routing** - FastAPI matches method/path to function +4. **Execution** - Function code runs with parameters +5. **Response** - Result serialized and returned + +**Two Execution Paths:** + +- **User Routes** - Direct HTTP from clients +- **Framework Routes** - @remote calls from local code via /execute + +**Key Characteristics:** + +- ✅ Low latency (direct HTTP) +- ✅ No queuing overhead +- ✅ Concurrent request handling +- ✅ FastAPI routing +- ✅ Serialized function execution via @remote + +**Security:** + +- Protect `/execute` endpoint with authentication +- Only allow @remote calls from trusted sources +- Monitor endpoint usage diff --git a/docs/Load_Balancer_Endpoints.md b/docs/Load_Balancer_Endpoints.md new file mode 100644 index 00000000..ea551884 --- /dev/null +++ b/docs/Load_Balancer_Endpoints.md @@ -0,0 +1,406 @@ +# Load-Balanced Serverless Endpoints + +## Overview + +The `LoadBalancerSlsResource` class enables provisioning and management of RunPod load-balanced serverless endpoints. Unlike queue-based endpoints that process requests sequentially, load-balanced endpoints expose HTTP servers directly to clients, enabling REST APIs, webhooks, and real-time communication patterns. + +This resource type forms the foundation for the Mothership architecture, which requires HTTP-based endpoint discovery and cross-endpoint communication. 
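+
+As a brief, hedged illustration of this "direct HTTP" model (the route, payload, and client library below are assumptions made for the sketch; see the Usage section for the actual provisioning API), a deployed load-balanced endpoint can be called like any other HTTP service:
+
+```python
+import httpx  # any HTTP client works; httpx is only an example choice
+
+async def call_process(endpoint_url: str) -> dict:
+    # endpoint_url is the value RunPod populates on the resource after deploy().
+    # Depending on endpoint configuration, an Authorization header may also be required.
+    async with httpx.AsyncClient(timeout=30) as client:
+        resp = await client.post(f"{endpoint_url}/api/process", json={"x": 5, "y": 3})
+        resp.raise_for_status()
+        return resp.json()  # e.g. {"result": 8}
+
+# Run with: asyncio.run(call_process(resource.endpoint_url))
+```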
+ +## Design Context + +### Problem Statement + +RunPod supports two serverless endpoint models: + +1. **Queue-Based (QB)**: Sequential processing with automatic retry logic + - Requests queued and processed one-at-a-time + - Built-in error recovery + - Higher latency but fault-tolerant + - Fixed request/response format (JSON) + +2. **Load-Balanced (LB)**: Direct HTTP routing to healthy workers + - Requests routed directly to available workers + - No automatic retries + - Lower latency but less fault tolerance + - Custom HTTP endpoints and protocols + +### Design Decision + +Load-balanced endpoints require different provisioning and health check logic than queue-based endpoints. `LoadBalancerSlsResource` extends `ServerlessResource` with LB-specific behavior: + +- **Type enforcement**: Always deploys as LB (never QB) +- **Scaler validation**: Requires REQUEST_COUNT scaler (not QUEUE_DELAY) +- **Health checks**: Polls `/ping` endpoint to verify worker availability +- **Post-deployment verification**: Waits for endpoint readiness before returning + +### Why This Matters + +The Mothership needs to serve as a directory server for child endpoints. This requires: +- HTTP-based service discovery (not queue-based) +- Ability to expose custom endpoints (`/directory`, `/ping`) +- Health checking to verify children are ready before routing traffic + +## Architecture + +### High-Level Flow + +```mermaid +graph TD + A["LoadBalancerSlsResource
instance created"] --> B["Validate LB config
Type=LB, REQUEST_COUNT scaler"] + B --> C["Check if already
deployed"] + C -->|Already deployed| D["Return existing
endpoint"] + C -->|New deployment| E["Call parent _do_deploy
Create via RunPod API"] + E --> F["Poll /ping endpoint
until healthy"] + F -->|Health check fails| G["Raise TimeoutError
Deployment failed"] + F -->|Health check passes| H["Return deployed
endpoint"] + + style A fill:#1976d2,stroke:#0d47a1,stroke-width:3px,color:#fff + style B fill:#ff6b35,stroke:#c41e0f,stroke-width:3px,color:#fff + style C fill:#1976d2,stroke:#0d47a1,stroke-width:3px,color:#fff + style E fill:#1976d2,stroke:#0d47a1,stroke-width:3px,color:#fff + style F fill:#0d7f1f,stroke:#0d4f1f,stroke-width:3px,color:#fff + style G fill:#c41e0f,stroke:#a41100,stroke-width:3px,color:#fff + style H fill:#0d7f1f,stroke:#0d4f1f,stroke-width:3px,color:#fff +``` + +### Configuration Hierarchy + +``` +ServerlessResource (base class) +├── type: ServerlessType = QB (queue-based) +├── scalerType: ServerlessScalerType = QUEUE_DELAY +├── Standard provisioning flow +└── Standard health checks (RunPod SDK) + +LoadBalancerSlsResource (LB-specific subclass) +├── type: ServerlessType = LB (always, cannot override) +├── scalerType: ServerlessScalerType = REQUEST_COUNT (required) +├── Enhanced provisioning flow +│ ├── Validation before deploy +│ └── Post-deployment health check polling +├── Async health check (_check_ping_endpoint) +├── Sync health check (is_deployed) +└── Health check polling (_wait_for_health) +``` + +### Health Check Mechanism + +Load-balanced endpoints require a `/ping` endpoint that responds with: +- **200 OK**: Worker is healthy and ready +- **204 No Content**: Worker is initializing (transient) +- **Other status**: Worker is unhealthy + +```mermaid +sequenceDiagram + participant Deploy as LoadBalancerSlsResource + participant RunPod as RunPod API + participant Worker as LB Endpoint + participant Ping as /ping Handler + + Deploy->>RunPod: saveEndpoint (type=LB) + RunPod->>Worker: Create endpoint + Worker->>Ping: Initialize + + loop Health Check Polling + Deploy->>Worker: GET /ping + alt Healthy + Worker->>Ping: Handle request + Ping->>Worker: Return 200 + Worker->>Deploy: 200 OK + Deploy->>Deploy: Deployment complete ✓ + else Initializing + Worker->>Ping: Still starting + Ping->>Worker: Return 204 + Worker->>Deploy: 204 No Content + Deploy->>Deploy: Wait and retry + else Unhealthy + Worker->>Worker: Error occurred + Worker->>Deploy: 500+ Error + Deploy->>Deploy: Retry with backoff + end + end +``` + +## Using @remote with LoadBalancer Endpoints + +This document focuses on the `LoadBalancerSlsResource` class implementation and architecture. 
+ +**Related documentation:** +- [Using @remote with Load-Balanced Endpoints](Using_Remote_With_LoadBalancer.md) - User guide for writing and testing load-balanced endpoints +- [LoadBalancer Runtime Architecture](LoadBalancer_Runtime_Architecture.md) - Technical details on what happens when deployed on RunPod, request flows, and execution patterns + +**In the user guide, you'll learn:** +- Quick start with `LiveLoadBalancer` for local development +- HTTP routing with `method` and `path` parameters +- Building and deploying load-balanced endpoints +- Complete working examples +- Troubleshooting common issues + +**In the runtime architecture guide, you'll learn:** +- Deployment architecture and container setup +- Request flow for both direct HTTP and @remote calls +- Dual endpoint model (/execute vs user routes) +- Security considerations +- Performance characteristics and monitoring + +## Usage + +### Basic Provisioning + +```python +from tetra_rp import LoadBalancerSlsResource + +# Create a load-balanced endpoint +mothership = LoadBalancerSlsResource( + name="mothership", + imageName="my-mothership-app:latest", + workersMin=1, + workersMax=3, + env={ + "FLASH_APP": "my_app", + "LOG_LEVEL": "INFO", + } +) + +# Deploy endpoint +deployed = await mothership.deploy() + +# Endpoint is now deployed and healthy +print(f"Endpoint ID: {deployed.id}") +print(f"Endpoint URL: {deployed.endpoint_url}") +``` + +### Configuration Options + +```python +LoadBalancerSlsResource( + # Required fields + name="my-endpoint", + imageName="my-image:latest", + + # Worker scaling (for LB, these control max concurrent requests) + workersMin=1, # Min number of workers to keep warm + workersMax=5, # Max workers to spin up + scalerValue=10, # Target concurrent requests per worker + + # Environment configuration + env={ + "ENV_VAR": "value", + }, + + # Network and storage + networkVolume=NetworkVolume(...), # Optional: persistent storage + + # Deployment location + datacenter=DataCenter.EU_RO_1, # Or US_EAST_1, etc + + # Timeouts + executionTimeoutMs=600000, # 10 minute timeout + idleTimeout=5, # Seconds before scaling down +) +``` + +### Health Checks + +```python +# Synchronous health check (for compatibility with RunPod SDK) +is_healthy = endpoint.is_deployed() + +# Asynchronous health check (for deployment flow) +is_healthy = await endpoint.is_deployed_async() + +# Direct ping check (for debugging) +is_responding = await endpoint._check_ping_endpoint() + +# Health check polling with custom parameters +healthy = await endpoint._wait_for_health( + max_retries=20, + retry_interval=3, # seconds +) +``` + +## Validation and Error Handling + +### Configuration Validation + +The resource validates LB-specific constraints at creation and deployment time: + +```python +# This will fail at validation time +try: + bad_endpoint = LoadBalancerSlsResource( + name="test", + imageName="test:latest", + scalerType=ServerlessScalerType.QUEUE_DELAY, # Not allowed for LB! + ) + await bad_endpoint.deploy() +except ValueError as e: + # Error: LoadBalancerSlsResource requires REQUEST_COUNT scaler, + # not QUEUE_DELAY. Load-balanced endpoints don't support queue-based scaling. 
+ print(f"Validation failed: {e}") +``` + +### Deployment Errors + +```python +try: + endpoint = LoadBalancerSlsResource( + name="mothership", + imageName="my-image:latest", + ) + deployed = await endpoint.deploy() +except TimeoutError as e: + # Health check failed after max retries + # Error: LB endpoint mothership (endpoint-id) failed to become + # healthy within 60s + print(f"Deployment failed: {e}") +except ValueError as e: + # RunPod API error or configuration issue + print(f"Deployment error: {e}") +``` + +## Type Safety + +`LoadBalancerSlsResource` enforces LB type at the class level: + +```python +# Type is always LB, cannot be changed +endpoint = LoadBalancerSlsResource( + name="test", + imageName="image", + type=ServerlessType.QB, # This gets overridden! +) + +assert endpoint.type == ServerlessType.LB # Always LB +``` + +## Performance Characteristics + +### Deployment Timeline + +| Phase | Duration | Notes | +|-------|----------|-------| +| API call | < 1s | RunPod endpoint creation | +| Worker initialization | 30-60s | Endpoint starts up | +| Health check polling | 5-50s | Depends on app startup time (10 retries × 5s = 50s max) | +| **Total** | **35-110s** | Typical: 60-90s | + +### Health Check Polling + +``` +Attempt 1: GET /ping → No response (endpoint starting) + Wait 5s +Attempt 2: GET /ping → 204 No Content (initializing) + Wait 5s +Attempt 3: GET /ping → 200 OK (healthy) ✓ + Deployment complete +``` + +Default configuration: +- Max retries: 10 +- Retry interval: 5 seconds +- Timeout per request: 5 seconds +- Total timeout: ~50 seconds + +## Comparison with Standard Endpoints + +| Feature | Queue-Based (QB) | Load-Balanced (LB) | +|---------|------------------|-------------------| +| Request model | Sequential queue | Direct HTTP routing | +| Retries | Automatic | Manual (client) | +| Latency | Higher (queuing) | Lower (direct) | +| Custom endpoints | Limited | Full HTTP support | +| Scalability | Per-function | Per-worker | +| Health checks | RunPod SDK | `/ping` endpoint | +| Use cases | Batch processing | APIs, webhooks, real-time | +| Suitable for | Workers | Mothership, services | + +## Implementation Details + +### Code Structure + +``` +LoadBalancerSlsResource (class) +├── __init__(...) 
+│ └── Enforce type=LB, scalerType=REQUEST_COUNT +├── _validate_lb_configuration() +│ └── Check scaler type, type field +├── is_deployed_async() +│ ├── Check endpoint ID +│ └── Call _check_ping_endpoint() +├── _check_ping_endpoint() +│ ├── GET /ping endpoint +│ └── Check status 200 or 204 +├── _wait_for_health(max_retries, retry_interval) +│ ├── Loop polling +│ ├── Exponential backoff +│ └── Return after success or timeout +├── _do_deploy() +│ ├── Call _validate_lb_configuration() +│ ├── Call parent _do_deploy() +│ ├── Call _wait_for_health() +│ └── Return deployed resource or raise TimeoutError +└── is_deployed() + └── Sync wrapper using RunPod SDK +``` + +### Thread Safety + +- `is_deployed()` is thread-safe (uses RunPod SDK) +- Async methods are safe for concurrent use +- Health check polling handles multiple concurrent calls + +## Troubleshooting + +### Health Check Timeout + +**Problem**: Deployment times out at health check step + +**Causes**: +- Endpoint failed to start (wrong image, runtime error) +- `/ping` endpoint not implemented +- `/ping` endpoint not responding within timeout +- Firewall/network blocking requests + +**Solution**: +- Verify image exists and runs correctly: `docker run my-image:latest` +- Implement `/ping` endpoint that returns 200 OK +- Check logs: `runpod-cli logs ` +- Increase timeout: `await endpoint._wait_for_health(max_retries=20)` + +### Configuration Validation Errors + +**Problem**: `ValueError: LoadBalancerSlsResource requires REQUEST_COUNT scaler` + +**Cause**: Scaler type set to QUEUE_DELAY + +**Solution**: +```python +# Remove scalerType specification (defaults to REQUEST_COUNT) +endpoint = LoadBalancerSlsResource( + name="test", + imageName="image", + # scalerType NOT specified, defaults to REQUEST_COUNT +) +``` + +### API Errors (401, 403, 429) + +**Problem**: RunPod GraphQL errors during deployment + +**Causes**: +- Missing or invalid RUNPOD_API_KEY +- Insufficient permissions +- Rate limiting + +**Solution**: +- Verify API key: `echo $RUNPOD_API_KEY` +- Check RunPod dashboard permissions +- Retry after delay for rate limits + +## Next Steps + +- **Mothership integration**: Use LoadBalancerSlsResource for Mothership endpoints +- **Service discovery**: Implement `/directory` endpoint for child endpoint discovery +- **Auto-provisioning**: Automatic child endpoint deployment on Mothership startup +- **Cross-endpoint routing**: Route requests between endpoints using service discovery diff --git a/docs/Using_Remote_With_LoadBalancer.md b/docs/Using_Remote_With_LoadBalancer.md new file mode 100644 index 00000000..952b805a --- /dev/null +++ b/docs/Using_Remote_With_LoadBalancer.md @@ -0,0 +1,548 @@ +# Using @remote with Load-Balanced Endpoints + +## Introduction + +Flash provides two ways to execute remote functions on serverless endpoints: queue-based (QB) and load-balanced (LB) endpoints. This guide covers using the `@remote` decorator with load-balanced endpoints for HTTP-based function execution. 
+ +### Queue-Based vs Load-Balanced Endpoints + +**Queue-Based Endpoints** (ServerlessEndpoint, LiveServerless) +- Requests queued and processed sequentially +- Automatic retry logic on failure +- Built-in fault tolerance +- Higher latency (queuing + processing) +- Fixed request/response format + +**Load-Balanced Endpoints** (LoadBalancerSlsResource, LiveLoadBalancer) +- Requests routed directly to available workers +- Direct HTTP execution, no queue +- No automatic retries +- Lower latency (direct HTTP) +- Custom HTTP routes and methods + +### When to Use Each Type + +Use **Load-Balanced** when you need: +- Low latency API endpoints +- Custom HTTP routing (GET, POST, PUT, DELETE) +- Direct HTTP response handling +- Handling multiple routes on single endpoint + +Use **Queue-Based** when you need: +- Automatic retry logic on failures +- Sequential, fault-tolerant processing +- Tolerance for higher latency +- Simple request/response pattern + +## Quick Start + +### Basic Example with LiveLoadBalancer + +For local development, use `LiveLoadBalancer`: + +```python +from tetra_rp import LiveLoadBalancer, remote + +# Create load-balanced endpoint +api = LiveLoadBalancer(name="example-api") + +# Define HTTP-routed function +@remote(api, method="POST", path="/api/greet") +async def greet_user(name: str): + return {"message": f"Hello, {name}!"} + +# Call the function locally +async def main(): + result = await greet_user("Alice") + print(result) # {"message": "Hello, Alice!"} + +if __name__ == "__main__": + import asyncio + asyncio.run(main()) +``` + +Key points: +- `method` parameter specifies HTTP method (GET, POST, PUT, DELETE, PATCH) +- `path` parameter specifies URL route (must start with `/`) +- Functions execute directly without deployment during development + +## HTTP Routing + +Load-balanced endpoints require explicit HTTP routing metadata in the `@remote` decorator. 
+ +### Parameters + +**method** (required for LoadBalancerSlsResource) +- Must be one of: GET, POST, PUT, DELETE, PATCH +- Case-insensitive (POST, post, Post all work) + +**path** (required for LoadBalancerSlsResource) +- Must start with `/` (e.g., `/api/process`, `/health`) +- Can include path parameters (e.g., `/api/users/{user_id}`) +- Cannot use reserved paths: `/execute`, `/ping` + +### Single Endpoint with Multiple Routes + +Multiple functions can share a single LoadBalancerSlsResource with different routes: + +```python +from tetra_rp import LiveLoadBalancer, remote + +api = LiveLoadBalancer(name="user-service") + +@remote(api, method="GET", path="/users") +def list_users(): + return {"users": []} + +@remote(api, method="POST", path="/users") +async def create_user(name: str, email: str): + return {"id": 1, "name": name, "email": email} + +@remote(api, method="GET", path="/users/{user_id}") +def get_user(user_id: int): + return {"id": user_id, "name": "Alice"} + +@remote(api, method="DELETE", path="/users/{user_id}") +async def delete_user(user_id: int): + return {"deleted": True} +``` + +When deployed: +- Single `user-service` endpoint created +- Four HTTP routes registered automatically +- FastAPI handles routing to correct function + +### Reserved Paths + +The following paths are reserved by Flash and cannot be used as user-defined routes: + +- `/ping` - Health check endpoint (required, returns 200 OK) + +**Important Security Note:** +- `/execute` - Framework endpoint for @remote stub execution (**only available with LiveLoadBalancer for local development**) + - Deployed `LoadBalancerSlsResource` endpoints **deliberately do NOT expose `/execute`** for security reasons + - The `/execute` endpoint accepts and executes arbitrary Python code - exposing it would allow remote code execution + - When using deployed endpoints, @remote calls are safely translated to HTTP requests to your user-defined routes + - Never manually add `/execute` to deployed endpoints + +Attempting to use `/ping` or `/execute` as user-defined routes will raise a validation error at build time. 
+ +## Local Development + +### Using LiveLoadBalancer + +For local development and testing, use `LiveLoadBalancer` instead of `LoadBalancerSlsResource`: + +```python +from tetra_rp import LiveLoadBalancer, remote + +api = LiveLoadBalancer(name="my-api") + +@remote(api, method="POST", path="/api/process") +async def process_data(x: int, y: int): + return {"result": x + y} + +# In tests or scripts, call directly +async def test(): + result = await process_data(5, 3) + assert result == {"result": 8} +``` + +**Key differences:** +- `LiveLoadBalancer` locks image to Tetra LB runtime (tetra-rp-lb) +- Functions execute directly without deployment +- Ideal for development and CI/CD testing +- Same `@remote` decorator interface as production + +### Testing Patterns + +```python +import pytest +from tetra_rp import LiveLoadBalancer, remote + +api = LiveLoadBalancer(name="test-api") + +@remote(api, method="POST", path="/api/calculate") +async def calculate(operation: str, a: int, b: int): + if operation == "add": + return a + b + elif operation == "multiply": + return a * b + else: + raise ValueError(f"Unknown operation: {operation}") + +@pytest.mark.asyncio +async def test_calculate_add(): + result = await calculate("add", 5, 3) + assert result == 8 + +@pytest.mark.asyncio +async def test_calculate_multiply(): + result = await calculate("multiply", 5, 3) + assert result == 15 + +@pytest.mark.asyncio +async def test_calculate_invalid(): + with pytest.raises(ValueError): + await calculate("unknown", 5, 3) +``` + +## Local vs Deployed Execution + +The behavior of `@remote` decorated functions differs between local development and deployed endpoints: + +### Local Development (LiveLoadBalancer) + +When using `LiveLoadBalancer` for local testing: +- Functions decorated with `@remote` serialize their code and POST to `/execute` endpoint +- The `/execute` endpoint accepts and executes the serialized function code +- Useful for development and CI/CD testing before deployment + +### Deployed Endpoints (LoadBalancerSlsResource) + +When deployed to production: +- Generated handlers do NOT expose `/execute` endpoint (security) +- Functions decorated with `@remote` are called via HTTP requests to their user-defined routes +- The stub automatically translates `@remote` calls into HTTP requests with mapped parameters +- Example: `await process_data(5, 3)` becomes `POST /api/process {"x": 5, "y": 3}` + +### Migration from Local to Deployed + +When migrating code from local testing to production: +- **No code changes needed** - `@remote` decorated functions work the same way +- The stub automatically detects whether it's `LiveLoadBalancer` (local) or `LoadBalancerSlsResource` (deployed) +- User-defined routes must be compatible with JSON serialization for parameters + +**Parameter Type Constraints on Deployed Endpoints:** + +When using deployed `LoadBalancerSlsResource` endpoints, function parameters are serialized to JSON in the HTTP request body. This means: + +**Supported types:** +- Primitive types: `int`, `str`, `bool`, `float` +- Collections: `list`, `dict`, `tuple`, `set` +- Nested structures: `list[dict[str, int]]`, etc. 
+- Optional types: `Optional[str]`, `Optional[int]` +- Special: `None` + +**Unsupported types:** +- Custom classes and dataclasses +- Request objects (FastAPI Request, Starlette Request) +- File/binary objects +- Complex Python objects that can't serialize to JSON +- Datetime objects (without custom serialization) + +**Example of parameter mapping:** + +```python +# Local call: +result = await process_data(5, "hello", [1, 2, 3]) + +# Gets translated to deployed endpoint call: +POST /api/process +{ + "x": 5, + "name": "hello", + "items": [1, 2, 3] +} +``` + +If you need to use complex types (e.g., File uploads, custom objects), use direct HTTP calls instead of the `@remote` decorator for deployed endpoints. For local development with `LiveLoadBalancer`, complex types work because the entire function is serialized and executed. + +## Building and Deploying + +### Build Process + +When you run `flash build`, the system: + +1. **Scans** your code for `@remote` decorated functions +2. **Extracts** HTTP routing metadata (method, path) +3. **Generates** FastAPI application with routes +4. **Creates** one handler file per LoadBalancerSlsResource +5. **Validates** routes for conflicts and reserved paths + +Example generated handler: + +```python +from fastapi import FastAPI +from tetra_rp.runtime.lb_handler import create_lb_handler + +# Imported from user code +from api.endpoints import process_data, health_check + +# Route registry built automatically +ROUTE_REGISTRY = { + ("POST", "/api/process"): process_data, + ("GET", "/api/health"): health_check, +} + +# FastAPI app created with routes +app = create_lb_handler(ROUTE_REGISTRY) + +if __name__ == "__main__": + import uvicorn + uvicorn.run(app, host="0.0.0.0", port=8000) +``` + +### Deployment Workflow + +```bash +# 1. Define functions with @remote decorator in your code +# 2. Test locally with LiveLoadBalancer +# 3. Build for production +flash build + +# 4. Configure your endpoint (optional) +# Edit flash.toml if needed to set image, GPU, etc. + +# 5. Deploy +flash deploy + +# 6. 
Check deployment status +flash status +``` + +### Verifying Deployment + +Once deployed, verify your endpoint: + +```bash +# Check endpoint is healthy +curl https:///ping +# Expected response: {"status": "healthy"} + +# Call your function via HTTP +curl -X POST https:///api/process \ + -H "Content-Type: application/json" \ + -d '{"x": 5, "y": 3}' +``` + +## Complete Working Example + +Here's a full example with multiple routes, error handling, and testing: + +```python +""" +user_service.py - Example load-balanced API service +""" + +from tetra_rp import LoadBalancerSlsResource, remote +from typing import Optional + +# For production, use LoadBalancerSlsResource +# For local development, use LiveLoadBalancer +api = LoadBalancerSlsResource( + name="user-service", + imageName="runpod/tetra-rp-lb:latest" +) + +class UserNotFound(Exception): + pass + +# In-memory database for example +users_db = { + 1: {"id": 1, "name": "Alice", "email": "alice@example.com"}, + 2: {"id": 2, "name": "Bob", "email": "bob@example.com"}, +} + +@remote(api, method="GET", path="/health") +def health_check(): + """Health check endpoint.""" + return {"status": "healthy"} + +@remote(api, method="GET", path="/users") +def list_users(): + """List all users.""" + return {"users": list(users_db.values())} + +@remote(api, method="POST", path="/users") +async def create_user(name: str, email: str): + """Create a new user.""" + user_id = max(users_db.keys() or [0]) + 1 + user = {"id": user_id, "name": name, "email": email} + users_db[user_id] = user + return user + +@remote(api, method="GET", path="/users/{user_id}") +def get_user(user_id: int): + """Get a specific user.""" + if user_id not in users_db: + raise UserNotFound(f"User {user_id} not found") + return users_db[user_id] + +@remote(api, method="PUT", path="/users/{user_id}") +async def update_user(user_id: int, name: Optional[str] = None, + email: Optional[str] = None): + """Update a user.""" + if user_id not in users_db: + raise UserNotFound(f"User {user_id} not found") + + user = users_db[user_id] + if name is not None: + user["name"] = name + if email is not None: + user["email"] = email + return user + +@remote(api, method="DELETE", path="/users/{user_id}") +async def delete_user(user_id: int): + """Delete a user.""" + if user_id not in users_db: + raise UserNotFound(f"User {user_id} not found") + + del users_db[user_id] + return {"deleted": True} +``` + +### Testing the Example + +```python +""" +test_user_service.py +""" + +import pytest +from tetra_rp import LiveLoadBalancer, remote +from typing import Optional + +# Use LiveLoadBalancer for testing +api = LiveLoadBalancer(name="user-service-test") + +# Define functions (same as above but use test endpoint) +# ... (function definitions) ... 
+ +@pytest.mark.asyncio +async def test_list_users(): + users = list_users() + assert "users" in users + assert isinstance(users["users"], list) + +@pytest.mark.asyncio +async def test_create_and_get_user(): + # Create a user + new_user = await create_user("Charlie", "charlie@example.com") + assert new_user["name"] == "Charlie" + assert new_user["id"] > 0 + + # Get the user + user = get_user(new_user["id"]) + assert user["name"] == "Charlie" + +@pytest.mark.asyncio +async def test_update_user(): + new_user = await create_user("Diana", "diana@example.com") + updated = await update_user(new_user["id"], name="Diana Updated") + assert updated["name"] == "Diana Updated" + +@pytest.mark.asyncio +async def test_delete_user(): + new_user = await create_user("Eve", "eve@example.com") + result = await delete_user(new_user["id"]) + assert result["deleted"] is True + + # Should raise error when trying to get deleted user + with pytest.raises(Exception): # UserNotFound + get_user(new_user["id"]) +``` + +## Troubleshooting + +### Validation Errors + +**"requires both 'method' and 'path'"** +- Problem: Using `@remote(lb_resource)` without method/path +- Solution: Add both parameters: `@remote(lb, method="POST", path="/api/endpoint")` + +**"Invalid HTTP method 'PATCH' must be one of: GET, POST, PUT, DELETE, PATCH"** +- Problem: Typo in HTTP method (e.g., `PTACH` instead of `PATCH`) +- Solution: Verify method spelling matches valid HTTP verbs + +**"path must start with '/'"** +- Problem: Path doesn't start with forward slash +- Solution: Use absolute paths: `/api/endpoint` not `api/endpoint` + +**"Route conflict detected: POST /api/process defined twice"** +- Problem: Two functions with same method and path on same endpoint +- Solution: Change path or method to make each route unique + +### Runtime Errors + +**"Endpoint URL not available - endpoint may not be deployed"** +- Problem: Using LoadBalancerSlsResource before calling `await resource.deploy()` +- Solution: Deploy the endpoint first (`await resource.deploy()`) which auto-populates endpoint_url, or use LiveLoadBalancer for local testing +- Note: endpoint_url is auto-generated by RunPod after deployment and cannot be manually specified + +**"HTTP error from endpoint: 500"** +- Problem: Function raised an error during execution +- Solution: Check function code for exceptions, view endpoint logs + +**"Execution timeout on user-service after 30s"** +- Problem: Function took longer than 30 seconds to complete +- Solution: Optimize function, consider increasing timeout in LoadBalancerSlsStub + +**"404 Not Found" or "404 error" when calling @remote on deployed endpoint** +- Problem: Function decorated with @remote but missing `method` and/or `path` parameters +- Solution: Always provide complete routing metadata: `@remote(api, method="POST", path="/api/endpoint")` +- Note: On `LoadBalancerSlsResource`, the stub will try to use the non-existent `/execute` endpoint if routing metadata is missing + +**"JSON serialization error" or "unexpected keyword argument" on deployed endpoint** +- Problem: Deployed endpoint receiving malformed parameters from @remote call +- Solution: This should not happen automatically (stub handles parameter mapping). 
Check: + - Function parameters are JSON-serializable (not custom classes or Request objects) + - Function signature matches expected parameter names + - For complex types, make direct HTTP calls instead of using @remote + +### Build Errors + +**"Cannot import module 'user_service'"** +- Problem: Function module not found during handler generation +- Solution: Ensure module is in Python path, check import statements + +**"Function 'process_data' not found in executed code"** +- Problem: Function source extraction failed +- Solution: Ensure function is defined at module level (not inside another function) + +## API Reference + +### @remote Decorator with LoadBalancerSlsResource + +```python +@remote( + resource_config: LoadBalancerSlsResource | LiveLoadBalancer, + method: str = None, # Required: GET, POST, PUT, DELETE, PATCH + path: str = None, # Required: /api/route + dependencies: List[str] = None, # Python packages to install + system_dependencies: List[str] = None, # System packages to install + accelerate_downloads: bool = True # Use download acceleration +) +def your_function(...): + pass +``` + +### LoadBalancerSlsResource + +See `docs/Load_Balancer_Endpoints.md` for detailed architecture and configuration options. + +### LiveLoadBalancer + +A test/development variant of LoadBalancerSlsResource: +- Locks to Tetra LB image +- Enables direct function calls without deployment +- Same decorator interface as production + +## Best Practices + +1. **Use LiveLoadBalancer for testing** - No deployment needed for development +2. **Test locally before deploying** - Catch routing/logic errors early +3. **Use descriptive paths** - `/api/users/{user_id}` is clearer than `/api/u` +4. **Group related routes** - Keep similar endpoints on same service +5. **Handle errors gracefully** - Return meaningful error messages to clients +6. **Verify health checks** - Ensure `/ping` endpoint works after deployment +7. 
**Document your API** - Add docstrings explaining what each route does + +## Next Steps + +- Review `docs/Load_Balancer_Endpoints.md` for LoadBalancerSlsResource class architecture +- Review `docs/LoadBalancer_Runtime_Architecture.md` for runtime execution and request flows +- Check examples in `flash-examples/` repository for more patterns +- Use `flash build --help` to see build options +- Use `flash run --help` to see local testing options diff --git a/src/tetra_rp/__init__.py b/src/tetra_rp/__init__.py index f7f21130..e72b89d7 100644 --- a/src/tetra_rp/__init__.py +++ b/src/tetra_rp/__init__.py @@ -15,12 +15,16 @@ from .client import remote from .core.resources import ( CpuInstanceType, + CpuLiveLoadBalancer, CpuLiveServerless, + CpuLoadBalancerSlsResource, CpuServerlessEndpoint, CudaVersion, DataCenter, GpuGroup, + LiveLoadBalancer, LiveServerless, + LoadBalancerSlsResource, NetworkVolume, PodTemplate, ResourceManager, @@ -36,47 +40,59 @@ def __getattr__(name): return remote elif name in ( - "CpuServerlessEndpoint", "CpuInstanceType", + "CpuLiveLoadBalancer", "CpuLiveServerless", + "CpuLoadBalancerSlsResource", + "CpuServerlessEndpoint", "CudaVersion", "DataCenter", "GpuGroup", + "LiveLoadBalancer", "LiveServerless", + "LoadBalancerSlsResource", + "NetworkVolume", "PodTemplate", "ResourceManager", "ServerlessEndpoint", "ServerlessType", - "NetworkVolume", ): from .core.resources import ( - CpuServerlessEndpoint, CpuInstanceType, + CpuLiveLoadBalancer, CpuLiveServerless, + CpuLoadBalancerSlsResource, + CpuServerlessEndpoint, CudaVersion, DataCenter, GpuGroup, + LiveLoadBalancer, LiveServerless, + LoadBalancerSlsResource, + NetworkVolume, PodTemplate, ResourceManager, ServerlessEndpoint, ServerlessType, - NetworkVolume, ) attrs = { - "CpuServerlessEndpoint": CpuServerlessEndpoint, "CpuInstanceType": CpuInstanceType, + "CpuLiveLoadBalancer": CpuLiveLoadBalancer, "CpuLiveServerless": CpuLiveServerless, + "CpuLoadBalancerSlsResource": CpuLoadBalancerSlsResource, + "CpuServerlessEndpoint": CpuServerlessEndpoint, "CudaVersion": CudaVersion, "DataCenter": DataCenter, "GpuGroup": GpuGroup, + "LiveLoadBalancer": LiveLoadBalancer, "LiveServerless": LiveServerless, + "LoadBalancerSlsResource": LoadBalancerSlsResource, + "NetworkVolume": NetworkVolume, "PodTemplate": PodTemplate, "ResourceManager": ResourceManager, "ServerlessEndpoint": ServerlessEndpoint, "ServerlessType": ServerlessType, - "NetworkVolume": NetworkVolume, } return attrs[name] raise AttributeError(f"module {__name__!r} has no attribute {name!r}") @@ -84,16 +100,20 @@ def __getattr__(name): __all__ = [ "remote", - "CpuServerlessEndpoint", "CpuInstanceType", + "CpuLiveLoadBalancer", "CpuLiveServerless", + "CpuLoadBalancerSlsResource", + "CpuServerlessEndpoint", "CudaVersion", "DataCenter", "GpuGroup", + "LiveLoadBalancer", "LiveServerless", + "LoadBalancerSlsResource", + "NetworkVolume", "PodTemplate", "ResourceManager", "ServerlessEndpoint", "ServerlessType", - "NetworkVolume", ] diff --git a/src/tetra_rp/cli/commands/build.py b/src/tetra_rp/cli/commands/build.py index b8d909d2..e4a53075 100644 --- a/src/tetra_rp/cli/commands/build.py +++ b/src/tetra_rp/cli/commands/build.py @@ -17,6 +17,7 @@ from ..utils.ignore import get_file_tree, load_ignore_patterns from .build_utils.handler_generator import HandlerGenerator +from .build_utils.lb_handler_generator import LBHandlerGenerator from .build_utils.manifest import ManifestBuilder from .build_utils.scanner import RemoteDecoratorScanner @@ -115,9 +116,30 @@ def build_command( 
manifest_path = build_dir / "flash_manifest.json" manifest_path.write_text(json.dumps(manifest, indent=2)) - # Generate handler files - handler_gen = HandlerGenerator(manifest, build_dir) - handler_paths = handler_gen.generate_handlers() + # Generate handler files based on resource type + handler_paths = [] + + # Separate resources by type + lb_resources = { + name: data + for name, data in manifest.get("resources", {}).items() + if data.get("resource_type") == "LoadBalancerSlsResource" + } + qb_resources = { + name: data + for name, data in manifest.get("resources", {}).items() + if data.get("resource_type") != "LoadBalancerSlsResource" + } + + # Generate LB handlers + if lb_resources: + lb_gen = LBHandlerGenerator(manifest, build_dir) + handler_paths.extend(lb_gen.generate_handlers()) + + # Generate QB handlers + if qb_resources: + qb_gen = HandlerGenerator(manifest, build_dir) + handler_paths.extend(qb_gen.generate_handlers()) progress.update( manifest_task, diff --git a/src/tetra_rp/cli/commands/build_utils/handler_generator.py b/src/tetra_rp/cli/commands/build_utils/handler_generator.py index f019d2d2..c1e5d616 100644 --- a/src/tetra_rp/cli/commands/build_utils/handler_generator.py +++ b/src/tetra_rp/cli/commands/build_utils/handler_generator.py @@ -43,10 +43,14 @@ def __init__(self, manifest: Dict[str, Any], build_dir: Path): self.build_dir = build_dir def generate_handlers(self) -> List[Path]: - """Generate all handler files.""" + """Generate all handler files for queue-based (non-LB) resources.""" handler_paths = [] for resource_name, resource_data in self.manifest.get("resources", {}).items(): + # Skip load-balanced resources (handled by LBHandlerGenerator) + if resource_data.get("resource_type") == "LoadBalancerSlsResource": + continue + handler_path = self._generate_handler(resource_name, resource_data) handler_paths.append(handler_path) diff --git a/src/tetra_rp/cli/commands/build_utils/lb_handler_generator.py b/src/tetra_rp/cli/commands/build_utils/lb_handler_generator.py new file mode 100644 index 00000000..c7bf0ba7 --- /dev/null +++ b/src/tetra_rp/cli/commands/build_utils/lb_handler_generator.py @@ -0,0 +1,200 @@ +"""Generator for FastAPI handlers for LoadBalancerSlsResource endpoints.""" + +import importlib.util +import logging +from pathlib import Path +from typing import Any, Dict, List + +logger = logging.getLogger(__name__) + +LB_HANDLER_TEMPLATE = '''""" +Auto-generated FastAPI handler for LoadBalancerSlsResource: {resource_name} +Generated at: {timestamp} + +This file is generated by the Flash build process. Do not edit manually. + +Load-balanced endpoints expose HTTP servers directly to clients, enabling: +- REST APIs with custom HTTP routing +- WebSocket servers +- Real-time communication patterns +""" + +from fastapi import FastAPI, Request +from tetra_rp.runtime.lb_handler import create_lb_handler + +# Import all functions/classes that belong to this resource +{imports} + +# Route registry: (method, path) -> function +ROUTE_REGISTRY = {{ +{registry} +}} + +# Create FastAPI app with routes +# Note: include_execute={include_execute} for this endpoint type +# - LiveLoadBalancer (local): include_execute=True for /execute endpoint +# - LoadBalancerSlsResource (deployed): include_execute=False (security) +app = create_lb_handler(ROUTE_REGISTRY, include_execute={include_execute}) + + +# Health check endpoint (required for RunPod load-balancer endpoints) +@app.get("/ping") +def ping(): + """Health check endpoint for RunPod load-balancer. 
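+    RunPod's load balancer polls this route to decide whether the worker can
+    receive traffic; a 200 or 204 response is treated as healthy.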
+ + Returns: + dict: Status response + """ + return {{"status": "healthy"}} + + +if __name__ == "__main__": + import uvicorn + # Local development server for testing + uvicorn.run(app, host="0.0.0.0", port=8000) +''' + + +class LBHandlerGenerator: + """Generates FastAPI handlers for LoadBalancerSlsResource endpoints.""" + + def __init__(self, manifest: Dict[str, Any], build_dir: Path): + self.manifest = manifest + self.build_dir = build_dir + + def generate_handlers(self) -> List[Path]: + """Generate all LB handler files.""" + handler_paths = [] + + for resource_name, resource_data in self.manifest.get("resources", {}).items(): + # Generate for both LiveLoadBalancer (local dev) and LoadBalancerSlsResource (deployed) + resource_type = resource_data.get("resource_type") + if resource_type not in ["LoadBalancerSlsResource", "LiveLoadBalancer"]: + continue + + handler_path = self._generate_handler(resource_name, resource_data) + handler_paths.append(handler_path) + + return handler_paths + + def _generate_handler( + self, resource_name: str, resource_data: Dict[str, Any] + ) -> Path: + """Generate a single FastAPI handler file.""" + handler_filename = f"handler_{resource_name}.py" + handler_path = self.build_dir / handler_filename + + # Get timestamp from manifest + timestamp = self.manifest.get("generated_at", "") + + # Determine if /execute endpoint should be included + # LiveLoadBalancer (local dev) includes /execute, deployed LoadBalancerSlsResource does not + resource_type = resource_data.get("resource_type", "LoadBalancerSlsResource") + include_execute = resource_type == "LiveLoadBalancer" + + # Generate imports section + imports = self._generate_imports(resource_data.get("functions", [])) + + # Generate route registry + registry = self._generate_route_registry(resource_data.get("functions", [])) + + # Format template + handler_code = LB_HANDLER_TEMPLATE.format( + resource_name=resource_name, + timestamp=timestamp, + imports=imports, + registry=registry, + include_execute=str(include_execute), + ) + + handler_path.write_text(handler_code) + + # Validate that generated handler can be imported + self._validate_handler_imports(handler_path) + + return handler_path + + def _generate_imports(self, functions: List[Dict[str, Any]]) -> str: + """Generate import statements for functions. + + Args: + functions: List of function metadata dicts + + Returns: + Import statements as string + """ + imports = [] + + for func in functions: + module = func.get("module") + name = func.get("name") + + if module and name: + imports.append(f"from {module} import {name}") + + return "\n".join(imports) if imports else "# No functions to import" + + def _generate_route_registry(self, functions: List[Dict[str, Any]]) -> str: + """Generate route registry for FastAPI app. + + Creates mapping of (method, path) tuples to function names. + + Args: + functions: List of function metadata dicts with http_method and http_path + + Returns: + Registry dictionary as string + """ + if not functions: + return " # No functions registered" + + registry_lines = [] + + for func in functions: + name = func.get("name") + method = func.get("http_method") + path = func.get("http_path") + + if name and method and path: + # Create tuple key: ("GET", "/api/process") + registry_lines.append(f' ("{method}", "{path}"): {name},') + elif name: + # Skip if method or path missing (shouldn't happen with validation) + logger.warning( + f"Function '{name}' missing http_method or http_path. Skipping." 
+ ) + + return ( + "\n".join(registry_lines) + if registry_lines + else " # No routes registered" + ) + + def _validate_handler_imports(self, handler_path: Path) -> None: + """Validate that generated handler has valid Python syntax. + + Attempts to load the handler module to catch syntax errors. + ImportErrors for missing worker modules are logged but not fatal, + as those imports may not be available at build time. + + Args: + handler_path: Path to generated handler file + + Raises: + ValueError: If handler has syntax errors or cannot be parsed + """ + try: + spec = importlib.util.spec_from_file_location("handler", handler_path) + if spec and spec.loader: + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + else: + raise ValueError("Failed to create module spec") + except SyntaxError as e: + raise ValueError(f"Handler has syntax errors: {e}") from e + except ImportError as e: + # Log but don't fail - imports might not be available at build time + logger.debug(f"Handler import validation: {e}") + except Exception as e: + # Only raise for truly unexpected errors + logger.warning(f"Handler validation warning: {e}") diff --git a/src/tetra_rp/cli/commands/build_utils/manifest.py b/src/tetra_rp/cli/commands/build_utils/manifest.py index 6df594d6..2664150f 100644 --- a/src/tetra_rp/cli/commands/build_utils/manifest.py +++ b/src/tetra_rp/cli/commands/build_utils/manifest.py @@ -2,12 +2,14 @@ import json from dataclasses import dataclass -from datetime import datetime +from datetime import datetime, timezone from pathlib import Path -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional from .scanner import RemoteFunctionMetadata +RESERVED_PATHS = ["/execute", "/ping"] + @dataclass class ManifestFunction: @@ -17,6 +19,8 @@ class ManifestFunction: module: str is_async: bool is_class: bool + http_method: Optional[str] = None # HTTP method for LB endpoints (GET, POST, etc.) + http_path: Optional[str] = None # HTTP path for LB endpoints (/api/process) @dataclass @@ -50,31 +54,75 @@ def build(self) -> Dict[str, Any]: # Build manifest structure resources_dict: Dict[str, Dict[str, Any]] = {} function_registry: Dict[str, str] = {} + routes_dict: Dict[ + str, Dict[str, str] + ] = {} # resource_name -> {route_key -> function_name} for resource_name, functions in sorted(resources.items()): handler_file = f"handler_{resource_name}.py" + # Use actual resource type from first function in group + resource_type = ( + functions[0].resource_type if functions else "LiveServerless" + ) + + # Validate and collect routing for LB endpoints + resource_routes = {} + is_load_balanced = resource_type in [ + "LoadBalancerSlsResource", + "LiveLoadBalancer", + ] + if is_load_balanced: + for f in functions: + if not f.http_method or not f.http_path: + raise ValueError( + f"{resource_type} endpoint '{resource_name}' requires " + f"method and path for function '{f.function_name}'. " + f"Got method={f.http_method}, path={f.http_path}" + ) + + # Check for route conflicts (same method + path) + route_key = f"{f.http_method} {f.http_path}" + if route_key in resource_routes: + raise ValueError( + f"Duplicate route '{route_key}' in resource '{resource_name}': " + f"both '{resource_routes[route_key]}' and '{f.function_name}' " + f"are mapped to the same route" + ) + resource_routes[route_key] = f.function_name + + # Check for reserved paths + if f.http_path in RESERVED_PATHS: + raise ValueError( + f"Function '{f.function_name}' cannot use reserved path '{f.http_path}'. 
" + f"Reserved paths: {', '.join(RESERVED_PATHS)}" + ) + functions_list = [ { "name": f.function_name, "module": f.module_path, "is_async": f.is_async, "is_class": f.is_class, + **( + {"http_method": f.http_method, "http_path": f.http_path} + if is_load_balanced + else {} + ), } for f in functions ] - # Use actual resource type from first function in group - resource_type = ( - functions[0].resource_type if functions else "LiveServerless" - ) - resources_dict[resource_name] = { "resource_type": resource_type, "handler_file": handler_file, "functions": functions_list, } + # Store routes for LB endpoints + if resource_routes: + routes_dict[resource_name] = resource_routes + # Build function registry for quick lookup for f in functions: if f.function_name in function_registry: @@ -84,14 +132,22 @@ def build(self) -> Dict[str, Any]: ) function_registry[f.function_name] = resource_name - return { + manifest = { "version": "1.0", - "generated_at": datetime.utcnow().isoformat() + "Z", + "generated_at": datetime.now(timezone.utc) + .isoformat() + .replace("+00:00", "Z"), "project_name": self.project_name, "resources": resources_dict, "function_registry": function_registry, } + # Add routes section if there are LB endpoints with routing + if routes_dict: + manifest["routes"] = routes_dict + + return manifest + def write_to_file(self, output_path: Path) -> Path: """Write manifest to file.""" manifest = self.build() diff --git a/src/tetra_rp/cli/commands/build_utils/scanner.py b/src/tetra_rp/cli/commands/build_utils/scanner.py index b6f1ecf4..92f80fc2 100644 --- a/src/tetra_rp/cli/commands/build_utils/scanner.py +++ b/src/tetra_rp/cli/commands/build_utils/scanner.py @@ -22,6 +22,8 @@ class RemoteFunctionMetadata: is_async: bool is_class: bool file_path: Path + http_method: Optional[str] = None # HTTP method for LB endpoints: GET, POST, etc. + http_path: Optional[str] = None # HTTP path for LB endpoints: /api/process class RemoteDecoratorScanner: @@ -86,7 +88,7 @@ def _extract_resource_configs(self, tree: ast.AST, py_file: Path) -> None: for node in ast.walk(tree): if isinstance(node, ast.Assign): - # Look for assignments like: gpu_config = LiveServerless(...) + # Look for assignments like: gpu_config = LiveServerless(...) or api = LiveLoadBalancer(...) for target in node.targets: if isinstance(target, ast.Name): variable_name = target.id @@ -134,6 +136,11 @@ def _extract_remote_functions( # Get resource type for this config resource_type = self._get_resource_type(resource_config_name) + # Extract HTTP routing metadata (for LB endpoints) + http_method, http_path = self._extract_http_routing( + remote_decorator + ) + metadata = RemoteFunctionMetadata( function_name=node.name, module_path=module_path, @@ -142,6 +149,8 @@ def _extract_remote_functions( is_async=is_async, is_class=is_class, file_path=py_file, + http_method=http_method, + http_path=http_path, ) functions.append(metadata) @@ -299,3 +308,40 @@ def _get_module_path(self, py_file: Path) -> str: except ValueError: # If relative_to fails, just use filename return py_file.stem + + def _extract_http_routing( + self, decorator: ast.expr + ) -> tuple[Optional[str], Optional[str]]: + """Extract HTTP method and path from @remote decorator. + + Returns: + Tuple of (method, path) or (None, None) if not found. 
+ method: GET, POST, PUT, DELETE, PATCH + path: /api/endpoint routes + + Raises: + ValueError: If method is not a valid HTTP verb + """ + if not isinstance(decorator, ast.Call): + return None, None + + http_method = None + http_path = None + + # Extract keyword arguments: method="POST", path="/api/process" + for keyword in decorator.keywords: + if keyword.arg == "method": + if isinstance(keyword.value, ast.Constant): + http_method = keyword.value.value + elif keyword.arg == "path": + if isinstance(keyword.value, ast.Constant): + http_path = keyword.value.value + + # Validate HTTP method if provided + valid_methods = {"GET", "POST", "PUT", "DELETE", "PATCH"} + if http_method is not None and http_method.upper() not in valid_methods: + raise ValueError( + f"Invalid HTTP method '{http_method}'. Must be one of: {', '.join(valid_methods)}" + ) + + return http_method, http_path diff --git a/src/tetra_rp/client.py b/src/tetra_rp/client.py index 0fa1826f..9dc8a019 100644 --- a/src/tetra_rp/client.py +++ b/src/tetra_rp/client.py @@ -4,7 +4,7 @@ from functools import wraps from typing import List, Optional -from .core.resources import ResourceManager, ServerlessResource +from .core.resources import LoadBalancerSlsResource, ResourceManager, ServerlessResource from .execute_class import create_remote_class from .stubs import stub_resource @@ -17,6 +17,8 @@ def remote( system_dependencies: Optional[List[str]] = None, accelerate_downloads: bool = True, local: bool = False, + method: Optional[str] = None, + path: Optional[str] = None, **extra, ): """ @@ -44,6 +46,12 @@ def remote( local (bool, optional): Execute function/class locally instead of provisioning remote servers. Returns the unwrapped function/class for direct local execution. Users must ensure all required dependencies are already installed in their local environment. Defaults to False. + method (str, optional): HTTP method for load-balanced endpoints (LoadBalancerSlsResource). + Required for LoadBalancerSlsResource: "GET", "POST", "PUT", "DELETE", "PATCH". + Ignored for queue-based endpoints. Defaults to None. + path (str, optional): HTTP path for load-balanced endpoints (LoadBalancerSlsResource). + Required for LoadBalancerSlsResource. Must start with "/". Example: "/api/process". + Ignored for queue-based endpoints. Defaults to None. extra (dict, optional): Additional parameters for the execution of the resource. Defaults to an empty dict. 
Returns: @@ -52,9 +60,9 @@ def remote( Example: ```python - # Async function (recommended style) + # Queue-based endpoint (recommended for reliability) @remote( - resource_config=my_resource_config, + resource_config=LiveServerless(name="gpu_worker"), dependencies=["torch>=2.0.0"], ) async def gpu_task(data: dict) -> dict: @@ -62,20 +70,19 @@ async def gpu_task(data: dict) -> dict: # GPU processing here return {"result": "processed"} - # Sync function (also supported) + # Load-balanced endpoint (for low-latency APIs) @remote( - resource_config=my_resource_config, - dependencies=["pandas>=2.0.0"], + resource_config=LoadBalancerSlsResource(name="api-service"), + method="POST", + path="/api/process", ) - def cpu_task(data: dict) -> dict: - import pandas as pd - # CPU processing here - return {"result": "processed"} + async def api_endpoint(x: int, y: int) -> dict: + return {"result": x + y} # Local execution (testing/development) @remote( resource_config=my_resource_config, - dependencies=["numpy", "pandas"], # Only used for remote execution + dependencies=["numpy", "pandas"], local=True, ) async def my_test_function(data): @@ -85,18 +92,53 @@ async def my_test_function(data): """ def decorator(func_or_class): + # Validate HTTP routing parameters for LoadBalancerSlsResource + is_lb_resource = isinstance(resource_config, LoadBalancerSlsResource) + + if is_lb_resource: + if not method or not path: + raise ValueError( + f"LoadBalancerSlsResource requires both 'method' and 'path' parameters. " + f"Got method={method}, path={path}. " + f"Example: @remote(resource_config, method='POST', path='/api/process')" + ) + if not path.startswith("/"): + raise ValueError(f"path must start with '/'. Got: {path}") + valid_methods = {"GET", "POST", "PUT", "DELETE", "PATCH"} + if method not in valid_methods: + raise ValueError( + f"method must be one of {valid_methods}. Got: {method}" + ) + elif method or path: + log.warning( + f"HTTP routing parameters (method={method}, path={path}) are only used " + f"with LoadBalancerSlsResource, but resource_config is {type(resource_config).__name__}. " + f"They will be ignored." 
+ ) + + # Store routing metadata for scanner and build system + routing_config = { + "resource_config": resource_config, + "method": method, + "path": path, + "dependencies": dependencies, + "system_dependencies": system_dependencies, + } + if os.getenv("RUNPOD_POD_ID") or os.getenv("RUNPOD_ENDPOINT_ID"): # Worker mode when running on RunPod platform + func_or_class.__remote_config__ = routing_config return func_or_class # Local execution mode - execute without provisioning remote servers if local: + func_or_class.__remote_config__ = routing_config return func_or_class # Remote execution mode if inspect.isclass(func_or_class): # Handle class decoration - return create_remote_class( + wrapped_class = create_remote_class( func_or_class, resource_config, dependencies, @@ -104,6 +146,8 @@ def decorator(func_or_class): accelerate_downloads, extra, ) + wrapped_class.__remote_config__ = routing_config + return wrapped_class else: # Handle function decoration @wraps(func_or_class) @@ -123,6 +167,8 @@ async def wrapper(*args, **kwargs): **kwargs, ) + # Store routing metadata on wrapper for scanner + wrapper.__remote_config__ = routing_config return wrapper return decorator diff --git a/src/tetra_rp/core/resources/__init__.py b/src/tetra_rp/core/resources/__init__.py index 60ab6b71..b47b50d9 100644 --- a/src/tetra_rp/core/resources/__init__.py +++ b/src/tetra_rp/core/resources/__init__.py @@ -2,23 +2,35 @@ from .cpu import CpuInstanceType from .gpu import GpuGroup, GpuType, GpuTypeDetail from .resource_manager import ResourceManager -from .live_serverless import LiveServerless, CpuLiveServerless +from .live_serverless import ( + CpuLiveLoadBalancer, + CpuLiveServerless, + LiveLoadBalancer, + LiveServerless, +) from .serverless import ( ServerlessResource, ServerlessEndpoint, JobOutput, CudaVersion, ServerlessType, + ServerlessScalerType, ) from .serverless_cpu import CpuServerlessEndpoint from .template import PodTemplate from .network_volume import NetworkVolume, DataCenter +from .load_balancer_sls_resource import ( + CpuLoadBalancerSlsResource, + LoadBalancerSlsResource, +) __all__ = [ "BaseResource", "CpuInstanceType", + "CpuLiveLoadBalancer", "CpuLiveServerless", + "CpuLoadBalancerSlsResource", "CpuServerlessEndpoint", "CudaVersion", "DataCenter", @@ -27,11 +39,14 @@ "GpuType", "GpuTypeDetail", "JobOutput", + "LiveLoadBalancer", "LiveServerless", + "LoadBalancerSlsResource", + "NetworkVolume", + "PodTemplate", "ResourceManager", - "ServerlessResource", "ServerlessEndpoint", + "ServerlessResource", + "ServerlessScalerType", "ServerlessType", - "PodTemplate", - "NetworkVolume", ] diff --git a/src/tetra_rp/core/resources/live_serverless.py b/src/tetra_rp/core/resources/live_serverless.py index 193810b0..7064189b 100644 --- a/src/tetra_rp/core/resources/live_serverless.py +++ b/src/tetra_rp/core/resources/live_serverless.py @@ -1,6 +1,12 @@ # Ship serverless code as you write it. No builds, no deploys — just run. 
import os + from pydantic import model_validator + +from .load_balancer_sls_resource import ( + CpuLoadBalancerSlsResource, + LoadBalancerSlsResource, +) from .serverless import ServerlessEndpoint from .serverless_cpu import CpuServerlessEndpoint @@ -11,6 +17,12 @@ TETRA_CPU_IMAGE = os.environ.get( "TETRA_CPU_IMAGE", f"runpod/tetra-rp-cpu:{TETRA_IMAGE_TAG}" ) +TETRA_LB_IMAGE = os.environ.get( + "TETRA_LB_IMAGE", f"runpod/tetra-rp-lb:{TETRA_IMAGE_TAG}" +) +TETRA_CPU_LB_IMAGE = os.environ.get( + "TETRA_CPU_LB_IMAGE", f"runpod/tetra-rp-lb-cpu:{TETRA_IMAGE_TAG}" +) class LiveServerlessMixin: @@ -60,3 +72,98 @@ def set_live_serverless_template(cls, data: dict): """Set default CPU image for Live Serverless.""" data["imageName"] = TETRA_CPU_IMAGE return data + + +class LiveLoadBalancer(LiveServerlessMixin, LoadBalancerSlsResource): + """Live load-balanced endpoint for local development and testing. + + Similar to LiveServerless but for HTTP-based load-balanced endpoints. + Enables local testing of @remote decorated functions with LB endpoints + before deploying to production. + + Features: + - Locks to Tetra LB image (tetra-rp-lb) + - Direct HTTP execution (not queue-based) + - Local development with flash run + - Same @remote decorator pattern as LoadBalancerSlsResource + + Usage: + from tetra_rp import LiveLoadBalancer, remote + + api = LiveLoadBalancer(name="api-service") + + @remote(api, method="POST", path="/api/process") + async def process_data(x: int, y: int): + return {"result": x + y} + + # Test locally + result = await process_data(5, 3) + + Local Development Flow: + 1. Create LiveLoadBalancer with routing + 2. Decorate functions with @remote(lb_resource, method=..., path=...) + 3. Run with `flash run` to start local endpoint + 4. Call functions directly in tests or scripts + 5. Deploy to production with `flash build` and `flash deploy` + + Note: + The endpoint_url is configured by the Flash runtime when the + endpoint is deployed locally. For true local testing without + deployment, use the functions directly or mock the HTTP layer. + """ + + @property + def _live_image(self) -> str: + return TETRA_LB_IMAGE + + @model_validator(mode="before") + @classmethod + def set_live_lb_template(cls, data: dict): + """Set default image for Live Load-Balanced endpoint.""" + data["imageName"] = TETRA_LB_IMAGE + return data + + +class CpuLiveLoadBalancer(LiveServerlessMixin, CpuLoadBalancerSlsResource): + """CPU-only live load-balanced endpoint for local development and testing. + + Similar to LiveLoadBalancer but configured for CPU instances with + automatic disk sizing and validation. + + Features: + - Locks to CPU Tetra LB image (tetra-rp-lb-cpu) + - CPU instance support with automatic disk sizing + - Direct HTTP execution (not queue-based) + - Local development with flash run + - Same @remote decorator pattern as CpuLoadBalancerSlsResource + + Usage: + from tetra_rp import CpuLiveLoadBalancer, remote + + api = CpuLiveLoadBalancer(name="api-service") + + @remote(api, method="POST", path="/api/process") + async def process_data(x: int, y: int): + return {"result": x + y} + + # Test locally + result = await process_data(5, 3) + + Local Development Flow: + 1. Create CpuLiveLoadBalancer with routing + 2. Decorate functions with @remote(lb_resource, method=..., path=...) + 3. Run with `flash run` to start local endpoint + 4. Call functions directly in tests or scripts + 5. 
Deploy to production with `flash build` and `flash deploy` + """ + + @property + def _live_image(self) -> str: + return TETRA_CPU_LB_IMAGE + + @model_validator(mode="before") + @classmethod + def set_live_cpu_lb_template(cls, data: dict): + """Set default CPU image for Live Load-Balanced endpoint.""" + data["imageName"] = TETRA_CPU_LB_IMAGE + return data diff --git a/src/tetra_rp/core/resources/load_balancer_sls_resource.py b/src/tetra_rp/core/resources/load_balancer_sls_resource.py new file mode 100644 index 00000000..1ea5085b --- /dev/null +++ b/src/tetra_rp/core/resources/load_balancer_sls_resource.py @@ -0,0 +1,374 @@ +""" +LoadBalancerSlsResource - Resource type for RunPod Load-Balanced Serverless endpoints. + +Load-balanced endpoints expose HTTP servers directly to clients without the queue-based +processing model of standard serverless endpoints. They're ideal for REST APIs, webhooks, +and real-time communication patterns. + +Key differences from standard serverless (QB): +- Requests route directly to healthy workers via HTTP +- No automatic retries (client responsible) +- Lower latency but less fault tolerance +- Requires HTTP application, not a function handler +- Health checks via /ping endpoint +""" + +import asyncio +import logging +from typing import List, Optional + +from pydantic import model_validator + +from tetra_rp.core.utils.http import get_authenticated_httpx_client +from .cpu import CpuInstanceType +from .serverless import ServerlessResource, ServerlessType, ServerlessScalerType +from .serverless_cpu import CpuEndpointMixin + +log = logging.getLogger(__name__) + +# Configuration constants +DEFAULT_HEALTH_CHECK_RETRIES = 10 +DEFAULT_HEALTH_CHECK_INTERVAL = 5 # seconds between retries +DEFAULT_PING_REQUEST_TIMEOUT = ( + 15.0 # seconds (load-balanced workers need time for cold starts) +) +HEALTHY_STATUS_CODES = (200, 204) + + +class LoadBalancerSlsResource(ServerlessResource): + """ + Resource configuration for RunPod Load-Balanced Serverless endpoints. + + Load-balanced endpoints expose HTTP servers directly, making them suitable for: + - REST APIs + - WebSocket servers + - Real-time streaming + - Custom HTTP protocols + + Configuration example: + mothership = LoadBalancerSlsResource( + name="mothership", + imageName="my-mothership:latest", + env={"FLASH_APP": "my_app"}, + workersMin=1, + workersMax=3, + ) + await mothership.deploy() + """ + + # Override default type to LB + type: Optional[ServerlessType] = ServerlessType.LB + + def __init__(self, **data): + """Initialize LoadBalancerSlsResource with LB-specific defaults.""" + # Ensure type is always LB + data["type"] = ServerlessType.LB + + # LB endpoints shouldn't use queue-based scaling + if "scalerType" not in data: + data["scalerType"] = ServerlessScalerType.REQUEST_COUNT + + super().__init__(**data) + + @model_validator(mode="after") + def set_serverless_template(self): + """Create template from imageName if not provided. + + Must run after sync_input_fields to ensure all input fields are synced. + """ + if not any([self.imageName, self.template, self.templateId]): + raise ValueError( + "Either imageName, template, or templateId must be provided" + ) + + if not self.templateId and not self.template: + self.template = self._create_new_template() + elif self.template: + self._configure_existing_template() + + return self + + @property + def endpoint_url(self) -> str: + """Get the endpoint URL for load-balanced endpoints. + + Load-balanced endpoints use a different URL format than standard + serverless endpoints. 
They use: https://{endpoint_id}.api.runpod.ai + + Returns: + The endpoint URL for health checks and direct HTTP requests + + Raises: + ValueError: If endpoint ID not set + """ + if not self.id: + raise ValueError("Endpoint ID not set. Cannot determine endpoint URL.") + return f"https://{self.id}.api.runpod.ai" + + def _validate_lb_configuration(self) -> None: + """ + Validate LB-specific configuration constraints. + + Raises: + ValueError: If configuration violates LB requirements + """ + # LB must use REQUEST_COUNT scaler, not QUEUE_DELAY + if self.scalerType == ServerlessScalerType.QUEUE_DELAY: + raise ValueError( + f"LoadBalancerSlsResource requires REQUEST_COUNT scaler, " + f"not {self.scalerType.value}. " + "Load-balanced endpoints don't support queue-based scaling." + ) + + # Type must always be LB + if self.type != ServerlessType.LB: + raise ValueError( + f"LoadBalancerSlsResource type must be LB, got {self.type.value}" + ) + + async def is_deployed_async(self) -> bool: + """ + Check if LB endpoint is deployed and /ping endpoint is responding. + + For LB endpoints, we verify: + 1. Endpoint ID exists (created in RunPod) + 2. /ping endpoint returns 200 or 204 + 3. Endpoint is in healthy state + + Returns: + True if endpoint is deployed and healthy, False otherwise + """ + try: + if not self.id: + return False + + # Use async health check for LB endpoints + return await self._check_ping_endpoint() + + except Exception as e: + log.debug(f"Error checking {self}: {e}") + return False + + async def _check_ping_endpoint(self) -> bool: + """ + Check if /ping endpoint is accessible and healthy. + + RunPod load-balancer endpoints require a /ping endpoint that returns: + - 200 OK: Worker is healthy and ready + - 204 No Content: Worker is initializing + - Other status: Worker is unhealthy + + Returns: + True if /ping endpoint responds with 200 or 204 + """ + try: + if not self.id: + return False + + ping_url = f"{self.endpoint_url}/ping" + + async with get_authenticated_httpx_client( + timeout=DEFAULT_PING_REQUEST_TIMEOUT + ) as client: + response = await client.get(ping_url) + return response.status_code in HEALTHY_STATUS_CODES + except Exception as e: + log.debug(f"Ping check failed for {self.name}: {e}") + return False + + async def _wait_for_health( + self, + max_retries: int = DEFAULT_HEALTH_CHECK_RETRIES, + retry_interval: int = DEFAULT_HEALTH_CHECK_INTERVAL, + ) -> bool: + """ + Poll /ping endpoint until endpoint is healthy or timeout. + + Args: + max_retries: Number of health check attempts + retry_interval: Seconds between health check attempts + + Returns: + True if endpoint became healthy, False if timeout + + Raises: + ValueError: If endpoint ID not set + """ + if not self.id: + raise ValueError("Cannot wait for health: endpoint not deployed") + + log.info( + f"Waiting for LB endpoint {self.name} ({self.id}) to become healthy... 
" + f"(max {max_retries} retries, {retry_interval}s interval)" + ) + + for attempt in range(max_retries): + try: + if await self._check_ping_endpoint(): + log.info( + f"LB endpoint {self.name} is healthy (attempt {attempt + 1})" + ) + return True + + log.debug( + f"Health check attempt {attempt + 1}/{max_retries} - " + f"endpoint not ready yet" + ) + + except Exception as e: + log.debug(f"Health check attempt {attempt + 1} failed: {e}") + + # Wait before next attempt (except on last attempt) + if attempt < max_retries - 1: + await asyncio.sleep(retry_interval) + + log.error( + f"LB endpoint {self.name} failed to become healthy after " + f"{max_retries} attempts" + ) + return False + + async def _do_deploy(self) -> "LoadBalancerSlsResource": + """ + Deploy LB endpoint and wait for health. + + Deployment flow: + 1. Validate LB configuration + 2. Call parent deploy (creates endpoint in RunPod) + 3. Poll /ping endpoint until healthy or timeout + 4. Return deployed resource + + Returns: + Deployed LoadBalancerSlsResource instance + + Raises: + ValueError: If LB configuration invalid or deployment fails + TimeoutError: If /ping endpoint doesn't respond in time + """ + # Validate before deploying + self._validate_lb_configuration() + + # Check if already deployed + if self.is_deployed(): + log.debug(f"{self} already deployed") + return self + + try: + # Call parent deploy (creates endpoint via RunPod API) + log.info(f"Deploying LB endpoint {self.name}...") + deployed = await super()._do_deploy() + + # Wait for /ping endpoint to become available + timeout_seconds = ( + DEFAULT_HEALTH_CHECK_RETRIES * DEFAULT_HEALTH_CHECK_INTERVAL + ) + log.info( + f"Endpoint created, waiting for /ping to respond " + f"({timeout_seconds}s timeout)..." + ) + + healthy = await self._wait_for_health( + max_retries=DEFAULT_HEALTH_CHECK_RETRIES, + retry_interval=DEFAULT_HEALTH_CHECK_INTERVAL, + ) + + if not healthy: + raise TimeoutError( + f"LB endpoint {self.name} ({deployed.id}) failed to become " + f"healthy within {timeout_seconds}s" + ) + + log.info(f"LB endpoint {self.name} ({deployed.id}) deployed and healthy") + return deployed + + except Exception as e: + log.error(f"Failed to deploy LB endpoint {self.name}: {e}") + raise + + def is_deployed(self) -> bool: + """ + Override is_deployed to use async health check. + + Note: This is a synchronous wrapper around the async health check. + Prefer is_deployed_async() in async contexts. + + Returns: + True if endpoint is deployed and /ping responds + """ + if not self.id: + return False + + try: + # Try the RunPod SDK health check (works for basic connectivity) + response = self.endpoint.health() + return response is not None + except Exception as e: + log.debug(f"RunPod health check failed for {self.name}: {e}") + return False + + +class CpuLoadBalancerSlsResource(CpuEndpointMixin, LoadBalancerSlsResource): + """CPU-only load-balanced endpoint with automatic disk sizing. + + Similar to LoadBalancerSlsResource but configured for CPU instances + instead of GPUs. Inherits CPU-specific functionality from CpuEndpointMixin + for automatic disk sizing and validation. + + Defaults to CPU_ANY instance type if not specified. 
+ + Configuration example: + mothership = CpuLoadBalancerSlsResource( + name="mothership", + imageName="my-mothership:latest", + env={"FLASH_APP": "my_app"}, + instanceIds=[CpuInstanceType.CPU3G_1_4], + workersMin=1, + workersMax=3, + ) + await mothership.deploy() + """ + + instanceIds: Optional[List[CpuInstanceType]] = [CpuInstanceType.ANY] + + # CPU endpoints exclude GPU-specific fields from API payload + # This prevents the RunPod GraphQL API from rejecting CPU endpoints with GPU-specific fields + _input_only = { + "id", + "cudaVersions", + "datacenter", + "env", + "flashboot", + "gpus", + "gpuIds", + "gpuCount", + "allowedCudaVersions", + "imageName", + "networkVolume", + } + + def _setup_cpu_template(self) -> None: + """Setup template, validating and creating/configuring as needed.""" + if not any([self.imageName, self.template, self.templateId]): + raise ValueError( + "Either imageName, template, or templateId must be provided" + ) + + if not self.templateId and not self.template: + self.template = self._create_new_template() + elif self.template: + self._configure_existing_template() + + @model_validator(mode="after") + def set_serverless_template(self): + """Create template from imageName if not provided. + + Overrides parent to call _sync_cpu_fields first to ensure GPU defaults + are overridden for CPU endpoints. + """ + # Sync CPU-specific fields first (override GPU defaults) + self._sync_cpu_fields() + + # Setup template with validation and creation + self._setup_cpu_template() + + return self diff --git a/src/tetra_rp/core/resources/serverless.py b/src/tetra_rp/core/resources/serverless.py index 4fc33907..8a7e650a 100644 --- a/src/tetra_rp/core/resources/serverless.py +++ b/src/tetra_rp/core/resources/serverless.py @@ -251,6 +251,26 @@ def sync_input_fields(self): return self + def _create_new_template(self) -> PodTemplate: + """Create a new PodTemplate with standard configuration.""" + return PodTemplate( + name=self.resource_id, + imageName=self.imageName, + env=KeyValuePair.from_dict(self.env or get_env_vars()), + ) + + def _configure_existing_template(self) -> None: + """Configure an existing template with necessary overrides.""" + if self.template is None: + return + + self.template.name = f"{self.resource_id}__{self.template.resource_id}" + + if self.imageName: + self.template.imageName = self.imageName + if self.env: + self.template.env = KeyValuePair.from_dict(self.env) + async def _sync_graphql_object_with_inputs( self, returned_endpoint: "ServerlessResource" ): @@ -587,28 +607,12 @@ class ServerlessEndpoint(ServerlessResource): Inherits from ServerlessResource. """ - def _create_new_template(self) -> PodTemplate: - """Create a new PodTemplate with standard configuration.""" - return PodTemplate( - name=self.resource_id, - imageName=self.imageName, - env=KeyValuePair.from_dict(self.env or get_env_vars()), - ) - - def _configure_existing_template(self) -> None: - """Configure an existing template with necessary overrides.""" - if self.template is None: - return - - self.template.name = f"{self.resource_id}__{self.template.resource_id}" - - if self.imageName: - self.template.imageName = self.imageName - if self.env: - self.template.env = KeyValuePair.from_dict(self.env) - @model_validator(mode="after") def set_serverless_template(self): + """Create template from imageName if not provided. + + Must run after sync_input_fields to ensure all input fields are synced. 
+ """ if not any([self.imageName, self.template, self.templateId]): raise ValueError( "Either imageName, template, or templateId must be provided" diff --git a/src/tetra_rp/core/resources/serverless_cpu.py b/src/tetra_rp/core/resources/serverless_cpu.py index 86835655..c4e6fd3b 100644 --- a/src/tetra_rp/core/resources/serverless_cpu.py +++ b/src/tetra_rp/core/resources/serverless_cpu.py @@ -24,6 +24,16 @@ class CpuEndpointMixin: instanceIds: Optional[List[CpuInstanceType]] + @field_validator("instanceIds") + @classmethod + def validate_instance_ids( + cls, value: List[CpuInstanceType] + ) -> List[CpuInstanceType]: + """Expand ANY to all available CPU instance types.""" + if value == [CpuInstanceType.ANY]: + return CpuInstanceType.all() + return value + def _is_cpu_endpoint(self) -> bool: """Check if this is a CPU endpoint (has instanceIds).""" return ( @@ -178,14 +188,6 @@ def _configure_existing_template(self) -> None: # Apply CPU-specific disk sizing self._apply_cpu_disk_sizing(self.template) - @field_validator("instanceIds") - @classmethod - def validate_cpus(cls, value: List[CpuInstanceType]) -> List[CpuInstanceType]: - """Expand ANY to all GPU groups""" - if value == [CpuInstanceType.ANY]: - return CpuInstanceType.all() - return value - @model_validator(mode="after") def set_serverless_template(self): # Sync CPU-specific fields first diff --git a/src/tetra_rp/core/resources/template.py b/src/tetra_rp/core/resources/template.py index a4c0a254..80a74c7c 100644 --- a/src/tetra_rp/core/resources/template.py +++ b/src/tetra_rp/core/resources/template.py @@ -1,6 +1,7 @@ -import requests +import warnings from typing import Dict, List, Optional, Any from pydantic import BaseModel, model_validator +from tetra_rp.core.utils.http import get_authenticated_requests_session from .base import BaseResource @@ -38,7 +39,7 @@ def sync_input_fields(self): def update_system_dependencies( - template_id, token, system_dependencies, base_entry_cmd=None + template_id, token=None, system_dependencies=None, base_entry_cmd=None ): """ Updates Runpod template with system dependencies installed via apt-get, @@ -46,12 +47,20 @@ def update_system_dependencies( Args: template_id (str): Runpod template ID. - token (str): Runpod API token. + token (str): [DEPRECATED] Runpod API token. Ignored; uses RUNPOD_API_KEY env var instead. system_dependencies (List[str]): List of apt packages to install. base_entry_cmd (List[str]): The default command to run the app, e.g. ["uv", "run", "handler.py"] Returns: dict: API response JSON or error info. """ + # Warn if deprecated token parameter is used + if token is not None: + warnings.warn( + "The 'token' parameter is deprecated and ignored. 
" + "Authentication now uses RUNPOD_API_KEY environment variable.", + DeprecationWarning, + stacklevel=2, + ) # Compose apt-get install command if any packages specified apt_cmd = "" @@ -83,12 +92,16 @@ def update_system_dependencies( "volumeMountPath": "/workspace", } - headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"} - url = f"https://rest.runpod.io/v1/templates/{template_id}/update" - response = requests.post(url, json=payload, headers=headers) + # Use centralized auth utility instead of manual header setup + # Note: token parameter is deprecated; uses RUNPOD_API_KEY environment variable + session = get_authenticated_requests_session() try: + response = session.post(url, json=payload) + response.raise_for_status() return response.json() - except Exception: - return {"error": "Invalid JSON response", "text": response.text} + except Exception as e: + return {"error": "Failed to update template", "details": str(e)} + finally: + session.close() diff --git a/src/tetra_rp/core/utils/http.py b/src/tetra_rp/core/utils/http.py new file mode 100644 index 00000000..ac6ac01e --- /dev/null +++ b/src/tetra_rp/core/utils/http.py @@ -0,0 +1,67 @@ +"""HTTP utilities for RunPod API communication.""" + +import os +from typing import Optional + +import httpx +import requests + + +def get_authenticated_httpx_client( + timeout: Optional[float] = None, +) -> httpx.AsyncClient: + """Create httpx AsyncClient with RunPod authentication. + + Automatically includes Authorization header if RUNPOD_API_KEY is set. + This provides a centralized place to manage authentication headers for + all RunPod HTTP requests, avoiding repetitive manual header addition. + + Args: + timeout: Request timeout in seconds. Defaults to 30.0. + + Returns: + Configured httpx.AsyncClient with Authorization header + + Example: + async with get_authenticated_httpx_client() as client: + response = await client.post(url, json=data) + + # With custom timeout + async with get_authenticated_httpx_client(timeout=60.0) as client: + response = await client.get(url) + """ + headers = {} + api_key = os.environ.get("RUNPOD_API_KEY") + if api_key: + headers["Authorization"] = f"Bearer {api_key}" + + timeout_config = timeout if timeout is not None else 30.0 + return httpx.AsyncClient(timeout=timeout_config, headers=headers) + + +def get_authenticated_requests_session() -> requests.Session: + """Create requests Session with RunPod authentication. + + Automatically includes Authorization header if RUNPOD_API_KEY is set. + Provides a centralized place to manage authentication headers for + synchronous RunPod HTTP requests. + + Returns: + Configured requests.Session with Authorization header + + Example: + session = get_authenticated_requests_session() + response = session.post(url, json=data, timeout=30.0) + # Remember to close: session.close() + + # Or use as context manager + import contextlib + with contextlib.closing(get_authenticated_requests_session()) as session: + response = session.post(url, json=data) + """ + session = requests.Session() + api_key = os.environ.get("RUNPOD_API_KEY") + if api_key: + session.headers["Authorization"] = f"Bearer {api_key}" + + return session diff --git a/src/tetra_rp/runtime/lb_handler.py b/src/tetra_rp/runtime/lb_handler.py new file mode 100644 index 00000000..6f7c198c --- /dev/null +++ b/src/tetra_rp/runtime/lb_handler.py @@ -0,0 +1,197 @@ +"""Factory for creating FastAPI load-balanced handlers. 
+ +This module provides the factory function for generating FastAPI applications +that handle load-balanced serverless endpoints. It supports both user-defined +HTTP routes and the framework's /execute endpoint for @remote function execution. + +Security Model: + The /execute endpoint accepts and executes serialized function code. This is + secure because: + 1. The function code originates from the client's @remote decorator + 2. The client (user) controls what function gets sent + 3. This mirrors the trusted client model of LiveServerlessStub + 4. In production, API authentication should protect the /execute endpoint + + Users should NOT expose the /execute endpoint to untrusted clients. +""" + +import base64 +import inspect +import logging +from typing import Any, Callable, Dict + +import cloudpickle +from fastapi import FastAPI, Request + +logger = logging.getLogger(__name__) + + +def create_lb_handler( + route_registry: Dict[tuple[str, str], Callable], include_execute: bool = False +) -> FastAPI: + """Create FastAPI app with routes from registry. + + Args: + route_registry: Mapping of (HTTP_METHOD, path) -> handler_function + Example: {("GET", "/api/health"): health_check} + include_execute: Whether to register /execute endpoint for @remote execution. + Only used for LiveLoadBalancer (local development). + Deployed endpoints should not expose /execute for security. + + Returns: + Configured FastAPI application with routes registered. + """ + app = FastAPI(title="Flash Load-Balanced Handler") + + # Register /execute endpoint for @remote stub execution (if enabled) + if include_execute: + + @app.post("/execute") + async def execute_remote_function(request: Request) -> Dict[str, Any]: + """Framework endpoint for @remote decorator execution. + + WARNING: This endpoint is INTERNAL to the Flash framework. It should only be + called by the @remote stub from tetra_rp.stubs.load_balancer_sls. Exposing + this endpoint to untrusted clients could allow arbitrary code execution. + + Accepts serialized function code and arguments, executes them, + and returns serialized result. 
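+        The request payload matches what LoadBalancerSlsStub._prepare_request
+        builds on the client side.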
+ + Request body: + { + "function_name": "process_data", + "function_code": "def process_data(x, y): return x + y", + "args": [base64_encoded_arg1, base64_encoded_arg2], + "kwargs": {"key": base64_encoded_value} + } + + Returns: + { + "success": true, + "result": base64_encoded_result + } + or + { + "success": false, + "error": "error message" + } + """ + try: + body = await request.json() + except Exception as e: + logger.error(f"Failed to parse request body: {e}") + return {"success": False, "error": f"Invalid request body: {e}"} + + try: + # Extract function metadata + function_name = body.get("function_name") + function_code = body.get("function_code") + + if not function_name or not function_code: + return { + "success": False, + "error": "Missing function_name or function_code in request", + } + + # Deserialize arguments + args = [] + for arg_b64 in body.get("args", []): + try: + arg = cloudpickle.loads(base64.b64decode(arg_b64)) + args.append(arg) + except Exception as e: + logger.error(f"Failed to deserialize argument: {e}") + return { + "success": False, + "error": f"Failed to deserialize argument: {e}", + } + + kwargs = {} + for key, val_b64 in body.get("kwargs", {}).items(): + try: + val = cloudpickle.loads(base64.b64decode(val_b64)) + kwargs[key] = val + except Exception as e: + logger.error(f"Failed to deserialize kwarg '{key}': {e}") + return { + "success": False, + "error": f"Failed to deserialize kwarg '{key}': {e}", + } + + # Execute function in isolated namespace + namespace: Dict[str, Any] = {} + try: + exec(function_code, namespace) + except SyntaxError as e: + logger.error(f"Syntax error in function code: {e}") + return { + "success": False, + "error": f"Syntax error in function code: {e}", + } + except Exception as e: + logger.error(f"Error executing function code: {e}") + return { + "success": False, + "error": f"Error executing function code: {e}", + } + + # Get function from namespace + if function_name not in namespace: + return { + "success": False, + "error": f"Function '{function_name}' not found in executed code", + } + + func = namespace[function_name] + + # Execute function + try: + result = func(*args, **kwargs) + + # Handle async functions + if inspect.iscoroutine(result): + result = await result + except Exception as e: + logger.error(f"Function execution failed: {e}") + return { + "success": False, + "error": f"Function execution failed: {e}", + } + + # Serialize result + try: + result_b64 = base64.b64encode(cloudpickle.dumps(result)).decode( + "utf-8" + ) + return {"success": True, "result": result_b64} + except Exception as e: + logger.error(f"Failed to serialize result: {e}") + return { + "success": False, + "error": f"Failed to serialize result: {e}", + } + + except Exception as e: + logger.error(f"Unexpected error in /execute endpoint: {e}") + return {"success": False, "error": f"Unexpected error: {e}"} + + # Register user-defined routes from registry + for (method, path), handler in route_registry.items(): + method_upper = method.upper() + + if method_upper == "GET": + app.get(path)(handler) + elif method_upper == "POST": + app.post(path)(handler) + elif method_upper == "PUT": + app.put(path)(handler) + elif method_upper == "DELETE": + app.delete(path)(handler) + elif method_upper == "PATCH": + app.patch(path)(handler) + else: + logger.warning( + f"Unsupported HTTP method '{method}' for path '{path}'. Skipping." 
+ ) + + return app diff --git a/src/tetra_rp/stubs/load_balancer_sls.py b/src/tetra_rp/stubs/load_balancer_sls.py new file mode 100644 index 00000000..1edd1a90 --- /dev/null +++ b/src/tetra_rp/stubs/load_balancer_sls.py @@ -0,0 +1,361 @@ +"""LoadBalancerSlsStub - Stub for load-balanced serverless execution. + +Enables @remote decorator to work with LoadBalancerSlsResource endpoints +via direct HTTP calls instead of queue-based job submission. +""" + +import base64 +import inspect +import logging +from typing import Any, Callable, Dict, List, Optional + +import httpx +import cloudpickle + +from tetra_rp.core.utils.http import get_authenticated_httpx_client +from .live_serverless import get_function_source + +log = logging.getLogger(__name__) + + +class LoadBalancerSlsStub: + """HTTP-based stub for load-balanced serverless endpoint execution. + + Implements the stub interface for @remote decorator with LoadBalancerSlsResource, + providing direct HTTP-based function execution instead of queue-based processing. + + Key differences from LiveServerlessStub: + - Direct HTTP POST to /execute endpoint (not queue-based) + - No job ID polling - synchronous HTTP response + - Same function serialization pattern (cloudpickle + base64) + - Lower latency but no automatic retries + + Architecture: + 1. User calls @remote decorated function + 2. Decorator dispatches to this stub via singledispatch + 3. Stub serializes function code and arguments + 4. Stub POSTs to endpoint /execute with serialized data + 5. Endpoint deserializes, executes, and returns result + 6. Stub deserializes result and returns to user + + Example: + stub = LoadBalancerSlsStub(lb_resource) + result = await stub(my_func, deps, sys_deps, accel, arg1, arg2) + """ + + DEFAULT_TIMEOUT = 30.0 # Default timeout in seconds + + def __init__(self, server: Any, timeout: Optional[float] = None) -> None: + """Initialize stub with LoadBalancerSlsResource server. + + Args: + server: LoadBalancerSlsResource instance with endpoint_url configured + timeout: Request timeout in seconds (default: 30.0) + """ + self.server = server + self.timeout = timeout if timeout is not None else self.DEFAULT_TIMEOUT + + def _should_use_execute_endpoint(self, func: Callable[..., Any]) -> bool: + """Determine if /execute endpoint should be used for this function. + + The /execute endpoint (which accepts arbitrary function code) is only used for: + - LiveLoadBalancer (local development) + - Functions without routing metadata (backward compatibility) + + For deployed LoadBalancerSlsResource endpoints with routing metadata, + the stub translates @remote calls into HTTP requests to user-defined routes. 
+ + Args: + func: Function being called + + Returns: + True if /execute should be used, False if user route should be used + """ + from ..core.resources.live_serverless import LiveLoadBalancer + + # Always use /execute for LiveLoadBalancer (local development) + if isinstance(self.server, LiveLoadBalancer): + log.debug(f"Using /execute endpoint for LiveLoadBalancer: {func.__name__}") + return True + + # Check if function has routing metadata + routing_config = getattr(func, "__remote_config__", None) + if not routing_config: + log.debug(f"No routing config for {func.__name__}, using /execute fallback") + return True + + # Check if routing metadata is complete + if not routing_config.get("method") or not routing_config.get("path"): + log.debug( + f"Incomplete routing config for {func.__name__}, using /execute fallback" + ) + return True + + # Use user-defined route for deployed endpoints with complete routing metadata + log.debug( + f"Using user route for deployed endpoint: {func.__name__} " + f"{routing_config['method']} {routing_config['path']}" + ) + return False + + async def __call__( + self, + func: Callable[..., Any], + dependencies: Optional[List[str]], + system_dependencies: Optional[List[str]], + accelerate_downloads: bool, + *args: Any, + **kwargs: Any, + ) -> Any: + """Execute function on load-balanced endpoint. + + Behavior depends on endpoint type: + - LiveLoadBalancer: Uses /execute endpoint (local development) + - Deployed LoadBalancerSlsResource: Uses user-defined route via HTTP + + Args: + func: Function to execute + dependencies: Pip dependencies required + system_dependencies: System dependencies required + accelerate_downloads: Whether to accelerate downloads + *args: Function positional arguments + **kwargs: Function keyword arguments + + Returns: + Function result + + Raises: + Exception: If endpoint returns error or HTTP call fails + """ + # Determine execution path based on resource type and routing metadata + if self._should_use_execute_endpoint(func): + # Local development or backward compatibility: use /execute endpoint + request = self._prepare_request( + func, + dependencies, + system_dependencies, + accelerate_downloads, + *args, + **kwargs, + ) + response = await self._execute_function(request) + return self._handle_response(response) + else: + # Deployed endpoint: use user-defined route + routing_config = func.__remote_config__ + return await self._execute_via_user_route( + func, + routing_config["method"], + routing_config["path"], + *args, + **kwargs, + ) + + def _prepare_request( + self, + func: Callable[..., Any], + dependencies: Optional[List[str]], + system_dependencies: Optional[List[str]], + accelerate_downloads: bool, + *args: Any, + **kwargs: Any, + ) -> Dict[str, Any]: + """Prepare HTTP request payload. + + Extracts function source code and serializes arguments using cloudpickle. 
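+        Arguments are cloudpickle-dumped and base64-encoded because raw bytes
+        are not JSON-serializable; the /execute handler reverses both steps.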
+ + Args: + func: Function to serialize + dependencies: Pip dependencies + system_dependencies: System dependencies + accelerate_downloads: Download acceleration flag + *args: Function arguments + **kwargs: Function keyword arguments + + Returns: + Request dictionary with serialized function and arguments + """ + source, _ = get_function_source(func) + log.debug(f"Extracted source for {func.__name__} ({len(source)} bytes)") + + request = { + "function_name": func.__name__, + "function_code": source, + "dependencies": dependencies or [], + "system_dependencies": system_dependencies or [], + "accelerate_downloads": accelerate_downloads, + } + + # Serialize arguments using cloudpickle + base64 + if args: + serialized_args = [ + base64.b64encode(cloudpickle.dumps(arg)).decode("utf-8") for arg in args + ] + request["args"] = serialized_args + log.debug(f"Serialized {len(args)} positional args for {func.__name__}") + + if kwargs: + serialized_kwargs = { + k: base64.b64encode(cloudpickle.dumps(v)).decode("utf-8") + for k, v in kwargs.items() + } + request["kwargs"] = serialized_kwargs + log.debug(f"Serialized {len(kwargs)} keyword args for {func.__name__}") + + return request + + async def _execute_function(self, request: Dict[str, Any]) -> Dict[str, Any]: + """Execute function via direct HTTP POST to endpoint. + + Posts serialized function and arguments to /execute endpoint. + No job ID polling - waits for synchronous HTTP response. + + Args: + request: Request dictionary with function_code, args, kwargs + + Returns: + Response dictionary with success flag and result + + Raises: + httpx.HTTPError: If HTTP request fails + ValueError: If endpoint_url not available + """ + if not self.server.endpoint_url: + raise ValueError( + "Endpoint URL not available - endpoint may not be deployed" + ) + + execute_url = f"{self.server.endpoint_url}/execute" + + try: + async with get_authenticated_httpx_client(timeout=self.timeout) as client: + response = await client.post(execute_url, json=request) + response.raise_for_status() + return response.json() + except httpx.TimeoutException as e: + raise TimeoutError( + f"Execution timeout on {self.server.name} after {self.timeout}s: {e}" + ) from e + except httpx.HTTPStatusError as e: + # Truncate response body to prevent huge error messages + response_text = e.response.text + if len(response_text) > 500: + response_text = response_text[:500] + "... (truncated)" + raise RuntimeError( + f"HTTP error from endpoint {self.server.name}: " + f"{e.response.status_code} - {response_text}" + ) from e + except httpx.RequestError as e: + raise ConnectionError( + f"Failed to connect to endpoint {self.server.name} ({execute_url}): {e}" + ) from e + + async def _execute_via_user_route( + self, + func: Callable[..., Any], + method: str, + path: str, + *args: Any, + **kwargs: Any, + ) -> Any: + """Execute function by calling user-defined HTTP route. + + Maps function arguments to JSON request body and makes HTTP request + to the user-defined route. The response is parsed as JSON and returned directly. 
+ + Args: + func: Function being called (used for signature inspection) + method: HTTP method (GET, POST, PUT, DELETE, PATCH) + path: URL path (e.g., /api/process) + *args: Function positional arguments + **kwargs: Function keyword arguments + + Returns: + Function result (parsed from JSON response) + + Raises: + ValueError: If endpoint_url not available + TimeoutError: If request times out + RuntimeError: If HTTP error occurs + ConnectionError: If connection fails + """ + if not self.server.endpoint_url: + raise ValueError( + "Endpoint URL not available - endpoint may not be deployed" + ) + + # Get function signature to map args to parameter names + sig = inspect.signature(func) + params = list(sig.parameters.keys()) + + # Map positional args to parameter names + body = {} + for i, arg in enumerate(args): + if i < len(params): + body[params[i]] = arg + body.update(kwargs) + + # Construct full URL + url = f"{self.server.endpoint_url}{path}" + log.debug(f"Executing via user route: {method} {url}") + + try: + async with get_authenticated_httpx_client(timeout=self.timeout) as client: + response = await client.request(method, url, json=body) + response.raise_for_status() + result = response.json() + log.debug( + f"User route execution successful (type={type(result).__name__})" + ) + return result + except httpx.TimeoutException as e: + raise TimeoutError( + f"Execution timeout on {self.server.name} after {self.timeout}s: {e}" + ) from e + except httpx.HTTPStatusError as e: + # Truncate response body to prevent huge error messages + response_text = e.response.text + if len(response_text) > 500: + response_text = response_text[:500] + "... (truncated)" + raise RuntimeError( + f"HTTP error from endpoint {self.server.name}: " + f"{e.response.status_code} - {response_text}" + ) from e + except httpx.RequestError as e: + raise ConnectionError( + f"Failed to connect to endpoint {self.server.name} ({url}): {e}" + ) from e + + def _handle_response(self, response: Dict[str, Any]) -> Any: + """Deserialize and validate response. 
+ + Args: + response: Response dictionary from endpoint + + Returns: + Deserialized function result + + Raises: + ValueError: If response format is invalid + Exception: If response indicates error + """ + if not isinstance(response, dict): + raise ValueError(f"Invalid response type: {type(response)}") + + if response.get("success"): + result_b64 = response.get("result") + if result_b64 is None: + raise ValueError("Response marked success but result is None") + + try: + result = cloudpickle.loads(base64.b64decode(result_b64)) + log.debug( + f"Successfully deserialized response result (type={type(result).__name__})" + ) + return result + except Exception as e: + raise ValueError(f"Failed to deserialize result: {e}") from e + else: + error = response.get("error", "Unknown error") + log.warning(f"Remote execution failed: {error}") + raise Exception(f"Remote execution failed: {error}") diff --git a/src/tetra_rp/stubs/registry.py b/src/tetra_rp/stubs/registry.py index 9ea94f45..c6363726 100644 --- a/src/tetra_rp/stubs/registry.py +++ b/src/tetra_rp/stubs/registry.py @@ -4,10 +4,13 @@ from ..core.resources import ( CpuLiveServerless, CpuServerlessEndpoint, + LiveLoadBalancer, LiveServerless, + LoadBalancerSlsResource, ServerlessEndpoint, ) from .live_serverless import LiveServerlessStub +from .load_balancer_sls import LoadBalancerSlsStub from .serverless import ServerlessEndpointStub log = logging.getLogger(__name__) @@ -115,3 +118,53 @@ async def stubbed_resource( return stub.handle_response(response) return stubbed_resource + + +@stub_resource.register(LoadBalancerSlsResource) +def _(resource, **extra): + """Create stub for LoadBalancerSlsResource (HTTP-based execution).""" + stub = LoadBalancerSlsStub(resource) + + async def stubbed_resource( + func, + dependencies, + system_dependencies, + accelerate_downloads, + *args, + **kwargs, + ) -> dict: + return await stub( + func, + dependencies, + system_dependencies, + accelerate_downloads, + *args, + **kwargs, + ) + + return stubbed_resource + + +@stub_resource.register(LiveLoadBalancer) +def _(resource, **extra): + """Create stub for LiveLoadBalancer (HTTP-based execution, local testing).""" + stub = LoadBalancerSlsStub(resource) + + async def stubbed_resource( + func, + dependencies, + system_dependencies, + accelerate_downloads, + *args, + **kwargs, + ) -> dict: + return await stub( + func, + dependencies, + system_dependencies, + accelerate_downloads, + *args, + **kwargs, + ) + + return stubbed_resource diff --git a/tests/integration/test_lb_remote_execution.py b/tests/integration/test_lb_remote_execution.py new file mode 100644 index 00000000..4d34abf3 --- /dev/null +++ b/tests/integration/test_lb_remote_execution.py @@ -0,0 +1,306 @@ +"""Integration tests for @remote with LoadBalancerSlsResource. + +These tests verify the full flow of using @remote with load-balanced endpoints, +including local development with LiveLoadBalancer and HTTP execution. 
+""" + +import base64 +import pytest +from unittest.mock import MagicMock + +import cloudpickle + +from tetra_rp import remote, LiveLoadBalancer, LoadBalancerSlsResource + + +class TestRemoteWithLoadBalancerIntegration: + """Integration tests for @remote decorator with LB endpoints.""" + + def test_decorator_accepts_lb_resource_with_routing(self): + """Test that @remote accepts LoadBalancerSlsResource with method/path.""" + lb = LoadBalancerSlsResource(name="test-api", imageName="test:latest") + + @remote(lb, method="POST", path="/api/process") + async def process_data(x: int, y: int): + return {"result": x + y} + + # Should not raise - decorator accepts the parameters + assert hasattr(process_data, "__remote_config__") + assert process_data.__remote_config__["method"] == "POST" + assert process_data.__remote_config__["path"] == "/api/process" + + def test_decorator_validates_method_and_path_required(self): + """Test that @remote requires both method and path for LB resources.""" + lb = LoadBalancerSlsResource(name="test-api", imageName="test:latest") + + with pytest.raises(ValueError, match="requires both 'method' and 'path'"): + + @remote(lb) + async def missing_routing(): + pass + + def test_decorator_validates_invalid_http_method(self): + """Test that @remote rejects invalid HTTP methods.""" + lb = LoadBalancerSlsResource(name="test-api", imageName="test:latest") + + with pytest.raises(ValueError, match="must be one of"): + + @remote(lb, method="INVALID", path="/api/test") + async def bad_method(): + pass + + def test_decorator_validates_path_starts_with_slash(self): + """Test that @remote requires path to start with /.""" + lb = LoadBalancerSlsResource(name="test-api", imageName="test:latest") + + with pytest.raises(ValueError, match="must start with '/'"): + + @remote(lb, method="GET", path="api/test") + async def bad_path(): + pass + + @pytest.mark.asyncio + async def test_remote_function_serialization_roundtrip(self): + """Test that function code and args serialize/deserialize correctly.""" + from tetra_rp.stubs.load_balancer_sls import LoadBalancerSlsStub + + mock_resource = MagicMock() + stub = LoadBalancerSlsStub(mock_resource) + + def add(x: int, y: int) -> int: + """Simple add function.""" + return x + y + + # Prepare request + request = stub._prepare_request(add, None, None, True, 5, 3) + + # Verify request structure + assert request["function_name"] == "add" + assert "def add" in request["function_code"] + assert len(request["args"]) == 2 + + # Deserialize and verify arguments + arg0 = cloudpickle.loads(base64.b64decode(request["args"][0])) + arg1 = cloudpickle.loads(base64.b64decode(request["args"][1])) + assert arg0 == 5 + assert arg1 == 3 + + @pytest.mark.asyncio + async def test_stub_response_deserialization(self): + """Test that response deserialization works correctly.""" + from tetra_rp.stubs.load_balancer_sls import LoadBalancerSlsStub + + mock_resource = MagicMock() + stub = LoadBalancerSlsStub(mock_resource) + + result_value = {"status": "success", "count": 42} + result_b64 = base64.b64encode(cloudpickle.dumps(result_value)).decode("utf-8") + + response = {"success": True, "result": result_b64} + + # Handle response + result = stub._handle_response(response) + + assert result == result_value + + def test_live_load_balancer_creation(self): + """Test that LiveLoadBalancer can be created and used with @remote.""" + lb = LiveLoadBalancer(name="test-live-api") + + @remote(lb, method="POST", path="/api/echo") + async def echo(message: str): + return {"echo": message} + + 
# Verify resource is correctly configured + # Note: name may have "-fb" appended by flash boot validator + assert "test-live-api" in lb.name + assert "tetra-rp-lb" in lb.imageName + assert echo.__remote_config__["method"] == "POST" + + def test_live_load_balancer_image_locked(self): + """Test that LiveLoadBalancer locks the image to Tetra LB image.""" + lb = LiveLoadBalancer(name="test-api") + + # Verify image is locked and cannot be overridden + original_image = lb.imageName + assert "tetra-rp-lb" in original_image + + # Try to set a different image (should be ignored due to property) + lb.imageName = "custom-image:latest" + + # Image should still be locked to Tetra + assert lb.imageName == original_image + + def test_load_balancer_vs_queue_based_endpoints(self): + """Test that LB and QB endpoints have different characteristics.""" + from tetra_rp import ServerlessEndpoint + + lb = LoadBalancerSlsResource(name="lb-api", imageName="test:latest") + qb = ServerlessEndpoint(name="qb-api", imageName="test:latest") + + @remote(lb, method="POST", path="/api/echo") + async def lb_func(): + return "lb" + + @remote(qb) + async def qb_func(): + return "qb" + + # Both should have __remote_config__ + assert hasattr(lb_func, "__remote_config__") + assert hasattr(qb_func, "__remote_config__") + + # LB should have routing config + assert lb_func.__remote_config__["method"] == "POST" + assert lb_func.__remote_config__["path"] == "/api/echo" + + # QB should have None values for routing (not LB-specific) + assert qb_func.__remote_config__["method"] is None + assert qb_func.__remote_config__["path"] is None + + def test_live_load_balancer_handler_includes_execute_endpoint(self): + """Test that generated handler for LiveLoadBalancer includes /execute endpoint.""" + from tetra_rp.cli.commands.build_utils.lb_handler_generator import ( + LBHandlerGenerator, + ) + from datetime import datetime, timezone + from pathlib import Path + import tempfile + + # Create a manifest for LiveLoadBalancer + manifest = { + "version": "1.0", + "generated_at": datetime.now(timezone.utc) + .isoformat() + .replace("+00:00", "Z"), + "project_name": "test-project", + "resources": { + "test-api": { + "resource_type": "LiveLoadBalancer", + "handler_file": "handler_test_api.py", + "functions": [ + { + "name": "process_data", + "module": "api.endpoints", + "is_async": True, + "is_class": False, + "http_method": "POST", + "http_path": "/api/process", + } + ], + } + }, + } + + with tempfile.TemporaryDirectory() as tmpdir: + build_dir = Path(tmpdir) + generator = LBHandlerGenerator(manifest, build_dir) + handlers = generator.generate_handlers() + + assert len(handlers) == 1 + handler_path = handlers[0] + handler_code = handler_path.read_text() + + # Verify the handler includes include_execute=True for LiveLoadBalancer + assert "include_execute=True" in handler_code + assert ( + "create_lb_handler(ROUTE_REGISTRY, include_execute=True)" + in handler_code + ) + + def test_deployed_load_balancer_handler_excludes_execute_endpoint(self): + """Test that generated handler for deployed LoadBalancerSlsResource excludes /execute endpoint.""" + from tetra_rp.cli.commands.build_utils.lb_handler_generator import ( + LBHandlerGenerator, + ) + from datetime import datetime, timezone + from pathlib import Path + import tempfile + + # Create a manifest for deployed LoadBalancerSlsResource + manifest = { + "version": "1.0", + "generated_at": datetime.now(timezone.utc) + .isoformat() + .replace("+00:00", "Z"), + "project_name": "test-project", + "resources": { 
+ "api-service": { + "resource_type": "LoadBalancerSlsResource", + "handler_file": "handler_api_service.py", + "functions": [ + { + "name": "process_data", + "module": "api.endpoints", + "is_async": True, + "is_class": False, + "http_method": "POST", + "http_path": "/api/process", + } + ], + } + }, + } + + with tempfile.TemporaryDirectory() as tmpdir: + build_dir = Path(tmpdir) + generator = LBHandlerGenerator(manifest, build_dir) + handlers = generator.generate_handlers() + + assert len(handlers) == 1 + handler_path = handlers[0] + handler_code = handler_path.read_text() + + # Verify the handler includes include_execute=False for deployed endpoints + assert "include_execute=False" in handler_code + assert ( + "create_lb_handler(ROUTE_REGISTRY, include_execute=False)" + in handler_code + ) + + def test_scanner_discovers_load_balancer_resources(self): + """Test that scanner can discover LiveLoadBalancer and LoadBalancerSlsResource.""" + from tetra_rp.cli.commands.build_utils.scanner import RemoteDecoratorScanner + from pathlib import Path + import tempfile + + # Create temporary Python file with LoadBalancer resource + code = """ +from tetra_rp import LiveLoadBalancer, LoadBalancerSlsResource, remote + +# Test LiveLoadBalancer discovery +api = LiveLoadBalancer(name="test-api") + +@remote(api, method="POST", path="/api/process") +async def process_data(x: int): + return {"result": x} + +# Test LoadBalancerSlsResource discovery +deployed = LoadBalancerSlsResource(name="deployed-api", imageName="test:latest") + +@remote(deployed, method="GET", path="/api/status") +def get_status(): + return {"status": "ok"} +""" + + with tempfile.TemporaryDirectory() as tmpdir: + project_dir = Path(tmpdir) + py_file = project_dir / "test_api.py" + py_file.write_text(code) + + scanner = RemoteDecoratorScanner(project_dir) + functions = scanner.discover_remote_functions() + + # Verify both resources were discovered + assert len(functions) == 2 + + # Verify resource types are correctly identified + resource_types = {f.resource_type for f in functions} + assert "LiveLoadBalancer" in resource_types + assert "LoadBalancerSlsResource" in resource_types + + # Verify resource configs were extracted + assert "test-api" in scanner.resource_types + assert scanner.resource_types["test-api"] == "LiveLoadBalancer" + assert "deployed-api" in scanner.resource_types + assert scanner.resource_types["deployed-api"] == "LoadBalancerSlsResource" diff --git a/tests/unit/core/utils/test_http.py b/tests/unit/core/utils/test_http.py new file mode 100644 index 00000000..d26c0954 --- /dev/null +++ b/tests/unit/core/utils/test_http.py @@ -0,0 +1,125 @@ +"""Tests for HTTP utilities for RunPod API communication.""" + +import requests +from tetra_rp.core.utils.http import ( + get_authenticated_httpx_client, + get_authenticated_requests_session, +) + + +class TestGetAuthenticatedHttpxClient: + """Test the get_authenticated_httpx_client utility function.""" + + def test_get_authenticated_httpx_client_with_api_key(self, monkeypatch): + """Test client includes auth header when API key is set.""" + monkeypatch.setenv("RUNPOD_API_KEY", "test-api-key-123") + + client = get_authenticated_httpx_client() + + assert client is not None + assert "Authorization" in client.headers + assert client.headers["Authorization"] == "Bearer test-api-key-123" + + def test_get_authenticated_httpx_client_without_api_key(self, monkeypatch): + """Test client works without API key (no auth header).""" + monkeypatch.delenv("RUNPOD_API_KEY", raising=False) + + client = 
get_authenticated_httpx_client() + + assert client is not None + assert "Authorization" not in client.headers + + def test_get_authenticated_httpx_client_custom_timeout(self, monkeypatch): + """Test client respects custom timeout.""" + monkeypatch.setenv("RUNPOD_API_KEY", "test-key") + + client = get_authenticated_httpx_client(timeout=60.0) + + assert client is not None + assert client.timeout.read == 60.0 + + def test_get_authenticated_httpx_client_default_timeout(self, monkeypatch): + """Test client uses default timeout when not specified.""" + monkeypatch.setenv("RUNPOD_API_KEY", "test-key") + + client = get_authenticated_httpx_client() + + assert client is not None + assert client.timeout.read == 30.0 + + def test_get_authenticated_httpx_client_timeout_none_uses_default( + self, monkeypatch + ): + """Test client uses default timeout when explicitly passed None.""" + monkeypatch.setenv("RUNPOD_API_KEY", "test-key") + + client = get_authenticated_httpx_client(timeout=None) + + assert client is not None + assert client.timeout.read == 30.0 + + def test_get_authenticated_httpx_client_empty_api_key_no_header(self, monkeypatch): + """Test that empty API key doesn't add Authorization header.""" + monkeypatch.setenv("RUNPOD_API_KEY", "") + + client = get_authenticated_httpx_client() + + assert client is not None + # Empty string is falsy, so no auth header should be added + assert "Authorization" not in client.headers + + def test_get_authenticated_httpx_client_zero_timeout(self, monkeypatch): + """Test client handles zero timeout correctly.""" + monkeypatch.setenv("RUNPOD_API_KEY", "test-key") + + client = get_authenticated_httpx_client(timeout=0.0) + + assert client is not None + assert client.timeout.read == 0.0 + + +class TestGetAuthenticatedRequestsSession: + """Test the get_authenticated_requests_session utility function.""" + + def test_get_authenticated_requests_session_with_api_key(self, monkeypatch): + """Test session includes auth header when API key is set.""" + monkeypatch.setenv("RUNPOD_API_KEY", "test-api-key-123") + + session = get_authenticated_requests_session() + + assert session is not None + assert "Authorization" in session.headers + assert session.headers["Authorization"] == "Bearer test-api-key-123" + session.close() + + def test_get_authenticated_requests_session_without_api_key(self, monkeypatch): + """Test session works without API key (no auth header).""" + monkeypatch.delenv("RUNPOD_API_KEY", raising=False) + + session = get_authenticated_requests_session() + + assert session is not None + assert "Authorization" not in session.headers + session.close() + + def test_get_authenticated_requests_session_empty_api_key_no_header( + self, monkeypatch + ): + """Test that empty API key doesn't add Authorization header.""" + monkeypatch.setenv("RUNPOD_API_KEY", "") + + session = get_authenticated_requests_session() + + assert session is not None + # Empty string is falsy, so no auth header should be added + assert "Authorization" not in session.headers + session.close() + + def test_get_authenticated_requests_session_is_valid_session(self, monkeypatch): + """Test returned object is a valid requests.Session.""" + monkeypatch.setenv("RUNPOD_API_KEY", "test-key") + + session = get_authenticated_requests_session() + + assert isinstance(session, requests.Session) + session.close() diff --git a/tests/unit/resources/test_cpu_load_balancer.py b/tests/unit/resources/test_cpu_load_balancer.py new file mode 100644 index 00000000..92a0955d --- /dev/null +++ 
b/tests/unit/resources/test_cpu_load_balancer.py @@ -0,0 +1,330 @@ +""" +Tests for CpuLoadBalancerSlsResource CPU-specific functionality. + +Ensures CPU load balancers exclude GPU-specific fields from RunPod API payloads +and override GPU defaults to CPU-appropriate values. +""" + +import os + +from tetra_rp.core.resources.cpu import CpuInstanceType +from tetra_rp.core.resources.load_balancer_sls_resource import ( + CpuLoadBalancerSlsResource, +) +from tetra_rp.core.resources.serverless import ServerlessType, ServerlessScalerType +from tetra_rp.core.resources.serverless_cpu import CpuServerlessEndpoint + +# Set a dummy API key for tests that create ResourceManager instances +os.environ.setdefault("RUNPOD_API_KEY", "test-key-for-unit-tests") + + +class TestCpuLoadBalancerDefaults: + """Test CpuLoadBalancerSlsResource default configuration.""" + + def test_cpu_load_balancer_creation_with_defaults(self): + """Test creating CpuLoadBalancerSlsResource with minimal config.""" + lb = CpuLoadBalancerSlsResource( + name="test-cpu-lb", + imageName="test/image:latest", + ) + + assert lb.name == "test-cpu-lb-fb" + assert lb.imageName == "test/image:latest" + assert lb.type == ServerlessType.LB + assert lb.scalerType == ServerlessScalerType.REQUEST_COUNT + + def test_cpu_load_balancer_with_custom_instances(self): + """Test explicit CPU instance type configuration.""" + lb = CpuLoadBalancerSlsResource( + name="test-cpu-lb", + imageName="test/image:latest", + instanceIds=[CpuInstanceType.CPU3G_1_4, CpuInstanceType.CPU3G_2_8], + ) + + assert lb.instanceIds == [CpuInstanceType.CPU3G_1_4, CpuInstanceType.CPU3G_2_8] + + def test_cpu_load_balancer_any_expansion(self): + """Test CpuInstanceType.ANY expansion.""" + lb = CpuLoadBalancerSlsResource( + name="test-cpu-lb", + imageName="test/image:latest", + instanceIds=[CpuInstanceType.ANY], + ) + + # ANY should expand to all CPU types + assert lb.instanceIds == CpuInstanceType.all() + assert CpuInstanceType.ANY not in lb.instanceIds + assert len(lb.instanceIds) == 12 + + +class TestCpuLoadBalancerGpuFieldOverride: + """Test that GPU fields are correctly overridden to CPU defaults.""" + + def test_sync_cpu_fields_overrides_gpu_defaults(self): + """Test _sync_cpu_fields overrides GPU defaults to CPU values.""" + lb = CpuLoadBalancerSlsResource( + name="test-cpu-lb", + imageName="test/image:latest", + instanceIds=[CpuInstanceType.CPU3G_1_4], + ) + + # GPU fields should be overridden to CPU defaults + assert lb.gpuCount == 0, "gpuCount should be 0 for CPU endpoints" + assert lb.allowedCudaVersions == "", "allowedCudaVersions should be empty" + assert lb.gpuIds == "", "gpuIds should be empty" + + def test_gpu_fields_not_hardcoded_in_constructor(self): + """Test that GPU fields are overridden even if passed to constructor.""" + # Attempting to set GPU-specific fields should be overridden + lb = CpuLoadBalancerSlsResource( + name="test-cpu-lb", + imageName="test/image:latest", + instanceIds=[CpuInstanceType.CPU3G_1_4], + gpuCount=4, # Should be overridden + allowedCudaVersions="12.0", # Should be overridden + ) + + assert lb.gpuCount == 0 + assert lb.allowedCudaVersions == "" + + +class TestCpuLoadBalancerInputOnlyExclusion: + """Test that _input_only set contains all GPU-specific fields.""" + + def test_input_only_contains_gpu_fields(self): + """Test _input_only set contains all GPU-specific fields.""" + lb = CpuLoadBalancerSlsResource( + name="test-cpu-lb", + imageName="test/image:latest", + ) + + # Critical fields that must be excluded + required_excludes = { + 
"gpuCount", + "allowedCudaVersions", + "gpuIds", + "cudaVersions", + "gpus", + } + for field in required_excludes: + assert field in lb._input_only, f"{field} must be in _input_only" + + def test_input_only_includes_common_fields(self): + """Test _input_only includes expected common fields.""" + lb = CpuLoadBalancerSlsResource( + name="test-cpu-lb", + imageName="test/image:latest", + ) + + expected_fields = { + "id", + "datacenter", + "env", + "flashboot", + "imageName", + "networkVolume", + } + for field in expected_fields: + assert field in lb._input_only + + +class TestCpuLoadBalancerPayloadExclusion: + """Test that GPU fields are excluded from model_dump payload.""" + + def test_model_dump_excludes_gpu_fields_from_payload(self): + """Test model_dump payload excludes GPU fields from API.""" + lb = CpuLoadBalancerSlsResource( + name="test-cpu-lb", + imageName="test/image:latest", + instanceIds=[CpuInstanceType.CPU3G_1_4], + ) + + payload = lb.model_dump(exclude=lb._input_only, exclude_none=True, mode="json") + + # GPU fields must NOT be in payload + assert "gpuCount" not in payload, "gpuCount should be excluded from payload" + assert "allowedCudaVersions" not in payload, ( + "allowedCudaVersions should be excluded" + ) + assert "gpuIds" not in payload, "gpuIds should be excluded" + assert "cudaVersions" not in payload, "cudaVersions should be excluded" + assert "gpus" not in payload, "gpus should be excluded" + + def test_model_dump_includes_cpu_fields_in_payload(self): + """Test model_dump payload includes CPU-specific fields.""" + lb = CpuLoadBalancerSlsResource( + name="test-cpu-lb", + imageName="test/image:latest", + instanceIds=[CpuInstanceType.CPU3G_1_4], + ) + + payload = lb.model_dump(exclude=lb._input_only, exclude_none=True, mode="json") + + # CPU fields must be in payload + assert "instanceIds" in payload + assert payload["instanceIds"] == ["cpu3g-1-4"] + + def test_model_dump_contains_required_lb_fields(self): + """Test model_dump includes required load balancer fields.""" + lb = CpuLoadBalancerSlsResource( + name="prod-api", + imageName="myapp/api:v1", + instanceIds=[CpuInstanceType.CPU3G_1_4], + workersMin=1, + workersMax=5, + ) + + payload = lb.model_dump(exclude=lb._input_only, exclude_none=True, mode="json") + + # Required LB fields + assert payload["name"] == "prod-api-fb" + assert payload["type"] == "LB" + assert payload["scalerType"] == "REQUEST_COUNT" + assert payload["workersMin"] == 1 + assert payload["workersMax"] == 5 + + def test_model_dump_excludes_template_image_name(self): + """Test imageName is excluded (sent via template object).""" + lb = CpuLoadBalancerSlsResource( + name="test", + imageName="test/image:latest", + ) + + payload = lb.model_dump(exclude=lb._input_only, exclude_none=True, mode="json") + + # imageName should be excluded (it's template-specific) + assert "imageName" not in payload + + def test_model_dump_includes_template_object(self): + """Test template object is included in payload.""" + lb = CpuLoadBalancerSlsResource( + name="test", + imageName="test/image:latest", + ) + + payload = lb.model_dump(exclude=lb._input_only, exclude_none=True, mode="json") + + # Template object should be present + assert "template" in payload + assert isinstance(payload["template"], dict) + assert "imageName" in payload["template"] + + +class TestCpuLoadBalancerComparison: + """Compare CpuLoadBalancerSlsResource with CpuServerlessEndpoint for consistency.""" + + def test_input_only_alignment_with_cpu_serverless(self): + """Test _input_only aligns with 
CpuServerlessEndpoint for GPU fields.""" + lb = CpuLoadBalancerSlsResource( + name="lb", + imageName="test:latest", + ) + + serverless = CpuServerlessEndpoint( + name="serverless", + imageName="test:latest", + ) + + # Critical GPU fields should be in both _input_only sets + gpu_fields = { + "gpuCount", + "allowedCudaVersions", + "gpuIds", + "cudaVersions", + "gpus", + } + + for field in gpu_fields: + assert field in lb._input_only, f"{field} should be in LB _input_only" + assert field in serverless._input_only, ( + f"{field} should be in Serverless _input_only" + ) + + def test_gpu_field_sync_consistency(self): + """Test GPU field values match between LB and Serverless.""" + lb = CpuLoadBalancerSlsResource( + name="lb", + imageName="test:latest", + instanceIds=[CpuInstanceType.CPU3G_1_4], + ) + + serverless = CpuServerlessEndpoint( + name="serverless", + imageName="test:latest", + instanceIds=[CpuInstanceType.CPU3G_1_4], + ) + + # Both should have identical GPU field values + assert lb.gpuCount == serverless.gpuCount == 0 + assert lb.allowedCudaVersions == serverless.allowedCudaVersions == "" + assert lb.gpuIds == serverless.gpuIds == "" + + +class TestCpuLoadBalancerIntegration: + """Integration tests for CPU load balancer deployment payloads.""" + + def test_deployment_payload_structure_is_valid(self): + """Test deployment payload has correct structure for RunPod API.""" + lb = CpuLoadBalancerSlsResource( + name="prod-api", + imageName="myapp/api:v1", + instanceIds=[CpuInstanceType.CPU3G_1_4], + workersMin=1, + workersMax=5, + scalerValue=10, + ) + + payload = lb.model_dump(exclude=lb._input_only, exclude_none=True, mode="json") + + # Verify payload structure + required_fields = {"name", "type", "scalerType", "workersMin", "workersMax"} + for field in required_fields: + assert field in payload, f"Required field {field} not in payload" + + # Verify no GPU fields + gpu_fields = {"gpuCount", "allowedCudaVersions", "gpuIds"} + for field in gpu_fields: + assert field not in payload, f"GPU field {field} should not be in payload" + + def test_cpu_disk_sizing_respects_limits(self): + """Test that CPU load balancer doesn't raise disk sizing errors on creation.""" + # This test verifies that we can create a CPU LB without disk sizing errors + # The actual disk sizing is applied when needed via _apply_cpu_disk_sizing + lb = CpuLoadBalancerSlsResource( + name="test-cpu-lb", + imageName="test/image:latest", + instanceIds=[CpuInstanceType.CPU3G_1_4], + ) + + # Should have a template + assert lb.template is not None + assert lb.template.imageName == "test/image:latest" + + def test_cpu_load_balancer_with_env_vars(self): + """Test CPU load balancer with environment variables.""" + env = { + "FLASH_APP": "my_app", + "LOG_LEVEL": "DEBUG", + } + + lb = CpuLoadBalancerSlsResource( + name="test", + imageName="test/image:latest", + env=env, + ) + + assert lb.env == env + + def test_cpu_load_balancer_with_worker_config(self): + """Test CPU load balancer with worker scaling configuration.""" + lb = CpuLoadBalancerSlsResource( + name="test", + imageName="test/image:latest", + workersMin=1, + workersMax=5, + scalerValue=10, + ) + + assert lb.workersMin == 1 + assert lb.workersMax == 5 + assert lb.scalerValue == 10 diff --git a/tests/unit/resources/test_live_load_balancer.py b/tests/unit/resources/test_live_load_balancer.py new file mode 100644 index 00000000..43981292 --- /dev/null +++ b/tests/unit/resources/test_live_load_balancer.py @@ -0,0 +1,265 @@ +""" +Unit tests for LiveLoadBalancer class and template 
serialization. +""" + +import os + +import pytest + +from tetra_rp.core.resources.cpu import CpuInstanceType +from tetra_rp.core.resources.live_serverless import ( + CpuLiveLoadBalancer, + LiveLoadBalancer, +) +from tetra_rp.core.resources.load_balancer_sls_resource import LoadBalancerSlsResource + + +class TestLiveLoadBalancer: + """Test LiveLoadBalancer class behavior.""" + + def test_live_load_balancer_creation_with_local_tag(self, monkeypatch): + """Test LiveLoadBalancer creates with local image tag.""" + monkeypatch.setenv("TETRA_IMAGE_TAG", "local") + # Need to reload the module to pick up new env var + import importlib + import tetra_rp.core.resources.live_serverless as ls_module + + importlib.reload(ls_module) + + lb = ls_module.LiveLoadBalancer(name="test-lb") + assert lb.imageName == "runpod/tetra-rp-lb:local" + assert lb.template is not None + assert lb.template.imageName == "runpod/tetra-rp-lb:local" + + def test_live_load_balancer_default_image_tag(self): + """Test LiveLoadBalancer uses default image tag.""" + # Clear any custom tag + os.environ.pop("TETRA_IMAGE_TAG", None) + + lb = LiveLoadBalancer(name="test-lb") + + assert "runpod/tetra-rp-lb:" in lb.imageName + assert lb.template is not None + assert lb.template.imageName == lb.imageName + + def test_live_load_balancer_template_creation(self): + """Test LiveLoadBalancer creates proper template from imageName.""" + lb = LiveLoadBalancer(name="cpu_processor") + + # Should have a template created from imageName + assert lb.template is not None + assert lb.template.imageName == lb.imageName + # Template name uses resource IDs, not the original name + assert "LiveLoadBalancer" in lb.template.name + assert "PodTemplate" in lb.template.name + + def test_live_load_balancer_template_env_variables(self): + """Test LiveLoadBalancer template includes environment variables.""" + lb = LiveLoadBalancer( + name="test-lb", + env={"CUSTOM_VAR": "custom_value"}, + ) + + assert lb.template is not None + assert lb.template.env is not None + assert len(lb.template.env) > 0 + + # Check for custom env var + custom_vars = [kv for kv in lb.template.env if kv.key == "CUSTOM_VAR"] + assert len(custom_vars) == 1 + assert custom_vars[0].value == "custom_value" + + def test_live_load_balancer_payload_serialization(self): + """Test LiveLoadBalancer serializes correctly for GraphQL deployment.""" + lb = LiveLoadBalancer(name="data_processor") + + # Generate payload as would be sent to RunPod + payload = lb.model_dump(exclude=lb._input_only, exclude_none=True, mode="json") + + # Template must be in payload (not imageName since that's in _input_only) + assert "template" in payload + assert "imageName" not in payload + + # Template must have all required fields + template = payload["template"] + assert "imageName" in template + assert "name" in template + assert template["imageName"] == lb.imageName + + def test_live_load_balancer_type_is_lb(self): + """Test LiveLoadBalancer has type=LB.""" + lb = LiveLoadBalancer(name="test-lb") + + assert lb.type.value == "LB" + assert str(lb.type) == "ServerlessType.LB" + + def test_live_load_balancer_scaler_is_request_count(self): + """Test LiveLoadBalancer uses REQUEST_COUNT scaler.""" + lb = LiveLoadBalancer(name="test-lb") + + assert lb.scalerType.value == "REQUEST_COUNT" + + +class TestLoadBalancerSlsResourceTemplate: + """Test LoadBalancerSlsResource template handling.""" + + def test_load_balancer_sls_with_image_name(self): + """Test LoadBalancerSlsResource creates template from imageName.""" + lb = 
LoadBalancerSlsResource( + name="test-lb", + imageName="runpod/tetra-rp-lb:latest", + ) + + assert lb.template is not None + assert lb.template.imageName == "runpod/tetra-rp-lb:latest" + + def test_load_balancer_sls_requires_image_template_or_id(self): + """Test LoadBalancerSlsResource requires one of: imageName, template, templateId.""" + with pytest.raises( + ValueError, + match="Either imageName, template, or templateId must be provided", + ): + LoadBalancerSlsResource(name="test-lb") + + def test_load_balancer_sls_with_template_id(self): + """Test LoadBalancerSlsResource works with templateId.""" + lb = LoadBalancerSlsResource( + name="test-lb", + templateId="template-123", + ) + + assert lb.templateId == "template-123" + assert lb.template is None + + +class TestTemplateSerializationRoundtrip: + """Test that template serialization works correctly for GraphQL.""" + + def test_live_load_balancer_serialization_roundtrip(self): + """Test that LiveLoadBalancer can be serialized and contains template.""" + lb = LiveLoadBalancer( + name="test-service", + env={"API_KEY": "secret123"}, + ) + + # Simulate what gets sent to RunPod + payload = lb.model_dump(exclude=lb._input_only, exclude_none=True, mode="json") + + # Verify GraphQL payload has template + assert "template" in payload, "Template must be in GraphQL payload" + assert payload["template"]["imageName"] is not None + assert payload["template"]["name"] is not None + + # Verify imageName is NOT in payload (it's in _input_only) + assert "imageName" not in payload + + # Verify the template has the correct image + assert "tetra-rp-lb:" in payload["template"]["imageName"], ( + "Must have load-balancer image" + ) + + def test_template_env_serialization(self): + """Test template environment variables serialize correctly.""" + lb = LiveLoadBalancer( + name="test-lb", + env={"VAR1": "value1", "VAR2": "value2"}, + ) + + payload = lb.model_dump(exclude=lb._input_only, exclude_none=True, mode="json") + + template_env = payload["template"]["env"] + assert isinstance(template_env, list) + assert len(template_env) >= 2 + + # Check env vars are serialized as {key, value} objects + var_keys = {kv["key"] for kv in template_env} + assert "VAR1" in var_keys + assert "VAR2" in var_keys + + +class TestCpuLiveLoadBalancer: + """Test CpuLiveLoadBalancer class behavior.""" + + def test_cpu_live_load_balancer_creation_with_local_tag(self, monkeypatch): + """Test CpuLiveLoadBalancer creates with local image tag.""" + monkeypatch.setenv("TETRA_IMAGE_TAG", "local") + # Need to reload the module to pick up new env var + import importlib + + import tetra_rp.core.resources.live_serverless as ls_module + + importlib.reload(ls_module) + + lb = ls_module.CpuLiveLoadBalancer(name="test-lb") + assert lb.imageName == "runpod/tetra-rp-lb-cpu:local" + assert lb.template is not None + assert lb.template.imageName == "runpod/tetra-rp-lb-cpu:local" + + def test_cpu_live_load_balancer_default_image_tag(self): + """Test CpuLiveLoadBalancer uses default CPU LB image tag.""" + # Clear any custom tag + os.environ.pop("TETRA_IMAGE_TAG", None) + + lb = CpuLiveLoadBalancer(name="test-lb") + + assert "runpod/tetra-rp-lb-cpu:" in lb.imageName + assert lb.template is not None + assert lb.template.imageName == lb.imageName + + def test_cpu_live_load_balancer_defaults_to_cpu_any(self): + """Test CpuLiveLoadBalancer expands CPU_ANY to all available types.""" + lb = CpuLiveLoadBalancer(name="test-lb") + + # ANY should expand to all available CPU instance types + assert lb.instanceIds == 
CpuInstanceType.all() + assert len(lb.instanceIds) == 12 # 4 cpu3g + 4 cpu3c + 4 cpu5c + + def test_cpu_live_load_balancer_with_specific_cpu_instances(self): + """Test CpuLiveLoadBalancer with explicit CPU instances.""" + lb = CpuLiveLoadBalancer( + name="test-lb", + instanceIds=[CpuInstanceType.CPU3G_1_4], + ) + + assert lb.instanceIds == [CpuInstanceType.CPU3G_1_4] + + def test_cpu_live_load_balancer_type_is_lb(self): + """Test CpuLiveLoadBalancer has type=LB.""" + lb = CpuLiveLoadBalancer(name="test-lb") + + assert lb.type.value == "LB" + assert str(lb.type) == "ServerlessType.LB" + + def test_cpu_live_load_balancer_scaler_is_request_count(self): + """Test CpuLiveLoadBalancer uses REQUEST_COUNT scaler.""" + lb = CpuLiveLoadBalancer(name="test-lb") + + assert lb.scalerType.value == "REQUEST_COUNT" + + def test_cpu_live_load_balancer_payload_serialization(self): + """Test CpuLiveLoadBalancer serializes correctly for GraphQL deployment.""" + lb = CpuLiveLoadBalancer(name="data_processor") + + # Generate payload as would be sent to RunPod + payload = lb.model_dump(exclude=lb._input_only, exclude_none=True, mode="json") + + # Template must be in payload (not imageName since that's in _input_only) + assert "template" in payload + assert "imageName" not in payload + + # Template must have all required fields + template = payload["template"] + assert "imageName" in template + assert "name" in template + assert template["imageName"] == lb.imageName + + def test_cpu_live_load_balancer_excludes_gpu_fields(self): + """Test CpuLiveLoadBalancer excludes GPU fields from payload.""" + lb = CpuLiveLoadBalancer(name="test-lb") + + payload = lb.model_dump(exclude=lb._input_only, exclude_none=True, mode="json") + + # GPU-specific fields should not be in payload + assert "gpus" not in payload + assert "gpuIds" not in payload + assert "cudaVersions" not in payload diff --git a/tests/unit/test_load_balancer_sls_resource.py b/tests/unit/test_load_balancer_sls_resource.py new file mode 100644 index 00000000..d73f694b --- /dev/null +++ b/tests/unit/test_load_balancer_sls_resource.py @@ -0,0 +1,525 @@ +""" +Tests for LoadBalancerSlsResource provisioning and health checks. 
+""" + +import os + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from tetra_rp.core.resources import ( + LoadBalancerSlsResource, + ServerlessType, + ServerlessScalerType, +) +from tetra_rp.core.resources.serverless import ServerlessResource + +# Set a dummy API key for tests that create ResourceManager instances +os.environ.setdefault("RUNPOD_API_KEY", "test-key-for-unit-tests") + + +class TestLoadBalancerSlsResourceCreation: + """Test LoadBalancerSlsResource creation and validation.""" + + def test_create_with_defaults(self): + """Test creating LoadBalancerSlsResource with minimal config.""" + resource = LoadBalancerSlsResource( + name="test-endpoint", + imageName="test-image:latest", + ) + + # Note: name gets -fb suffix added by sync_input_fields due to flashboot=True + assert resource.name == "test-endpoint-fb" + assert resource.imageName == "test-image:latest" + assert resource.type == ServerlessType.LB + assert resource.scalerType == ServerlessScalerType.REQUEST_COUNT + + def test_type_always_lb(self): + """Test that type is always LB regardless of input.""" + # Try to set type to QB - should be overridden to LB + resource = LoadBalancerSlsResource( + name="test", + imageName="image", + type=ServerlessType.QB, # This should be overridden + ) + + assert resource.type == ServerlessType.LB + + def test_scaler_type_defaults_to_request_count(self): + """Test that scaler type defaults to REQUEST_COUNT for LB.""" + resource = LoadBalancerSlsResource( + name="test", + imageName="image", + ) + + assert resource.scalerType == ServerlessScalerType.REQUEST_COUNT + + def test_validate_lb_configuration_rejects_queue_delay(self): + """Test that QUEUE_DELAY scaler is rejected for LB endpoints.""" + resource = LoadBalancerSlsResource( + name="test", + imageName="image", + scalerType=ServerlessScalerType.QUEUE_DELAY, + ) + + with pytest.raises(ValueError, match="requires REQUEST_COUNT scaler"): + resource._validate_lb_configuration() + + def test_with_custom_env_vars(self): + """Test creating LB resource with custom environment variables.""" + env = { + "FLASH_APP": "my_app", + "LOG_LEVEL": "DEBUG", + } + + resource = LoadBalancerSlsResource( + name="test", + imageName="image", + env=env, + ) + + assert resource.env == env + + def test_with_worker_config(self): + """Test creating LB resource with worker scaling config.""" + resource = LoadBalancerSlsResource( + name="test", + imageName="image", + workersMin=1, + workersMax=5, + scalerValue=10, + ) + + assert resource.workersMin == 1 + assert resource.workersMax == 5 + assert resource.scalerValue == 10 + + def test_endpoint_url_format_for_load_balanced_endpoints(self): + """Test that endpoint_url uses load-balanced format, not v2 API format.""" + resource = LoadBalancerSlsResource( + name="test", + imageName="image", + id="6g2hfns3ar5pti", + ) + + # Load-balanced endpoints use: https://{id}.api.runpod.ai + # NOT: https://api.runpod.ai/v2/{id} + assert resource.endpoint_url == "https://6g2hfns3ar5pti.api.runpod.ai" + + def test_endpoint_url_raises_without_id(self): + """Test that endpoint_url raises error when endpoint ID not set.""" + resource = LoadBalancerSlsResource( + name="test", + imageName="image", + ) + + with pytest.raises(ValueError, match="Endpoint ID not set"): + _ = resource.endpoint_url + + +class TestLoadBalancerSlsResourceHealthCheck: + """Test health check functionality.""" + + @staticmethod + def _create_mock_client( + status_code: int = 200, error: Exception = None + ) -> MagicMock: + """Create properly 
configured async context manager mock client.""" + mock_response = AsyncMock() + mock_response.status_code = status_code + mock_client = MagicMock() + if error: + mock_client.get = AsyncMock(side_effect=error) + else: + mock_client.get = AsyncMock(return_value=mock_response) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + return mock_client + + @pytest.mark.asyncio + async def test_check_ping_endpoint_success(self): + """Test successful ping endpoint check with ID set.""" + resource = LoadBalancerSlsResource( + name="test", + imageName="image", + id="test-endpoint-id", + ) + + mock_client = self._create_mock_client(200) + with ( + patch.object( + LoadBalancerSlsResource, + "endpoint_url", + new_callable=lambda: property(lambda self: "https://test-endpoint.com"), + ), + patch( + "tetra_rp.core.utils.http.httpx.AsyncClient", + return_value=mock_client, + ), + ): + result = await resource._check_ping_endpoint() + + assert result is True + + @pytest.mark.asyncio + async def test_check_ping_endpoint_initializing(self): + """Test ping endpoint returning 204 (initializing).""" + resource = LoadBalancerSlsResource( + name="test", + imageName="image", + id="test-endpoint-id", + ) + + mock_client = self._create_mock_client(204) + with ( + patch.object( + LoadBalancerSlsResource, + "endpoint_url", + new_callable=lambda: property(lambda self: "https://test-endpoint.com"), + ), + patch( + "tetra_rp.core.utils.http.httpx.AsyncClient", + return_value=mock_client, + ), + ): + result = await resource._check_ping_endpoint() + + assert result is True + + @pytest.mark.asyncio + async def test_check_ping_endpoint_failure(self): + """Test ping endpoint returning unhealthy status.""" + resource = LoadBalancerSlsResource( + name="test", + imageName="image", + id="test-endpoint-id", + ) + + mock_client = self._create_mock_client(503) + with ( + patch.object( + LoadBalancerSlsResource, + "endpoint_url", + new_callable=lambda: property(lambda self: "https://test-endpoint.com"), + ), + patch( + "tetra_rp.core.utils.http.httpx.AsyncClient", + return_value=mock_client, + ), + ): + result = await resource._check_ping_endpoint() + + assert result is False + + @pytest.mark.asyncio + async def test_check_ping_endpoint_connection_error(self): + """Test ping endpoint with connection error.""" + resource = LoadBalancerSlsResource( + name="test", + imageName="image", + id="test-endpoint-id", + ) + + mock_client = self._create_mock_client( + error=ConnectionError("Connection refused") + ) + with ( + patch.object( + LoadBalancerSlsResource, + "endpoint_url", + new_callable=lambda: property(lambda self: "https://test-endpoint.com"), + ), + patch( + "tetra_rp.core.utils.http.httpx.AsyncClient", + return_value=mock_client, + ), + ): + result = await resource._check_ping_endpoint() + + assert result is False + + @pytest.mark.asyncio + async def test_check_ping_endpoint_no_id(self): + """Test ping check when endpoint ID is not set.""" + resource = LoadBalancerSlsResource( + name="test", + imageName="image", + # id not set + ) + + result = await resource._check_ping_endpoint() + assert result is False + + @pytest.mark.asyncio + async def test_wait_for_health_success(self): + """Test health check polling with successful response.""" + resource = LoadBalancerSlsResource( + name="test", + imageName="image", + id="test-endpoint-id", + ) + + with patch.object(resource, "_check_ping_endpoint") as mock_check: + mock_check.return_value = True + + result = await 
resource._wait_for_health(max_retries=3) + + assert result is True + mock_check.assert_called_once() + + @pytest.mark.asyncio + async def test_wait_for_health_retry_then_success(self): + """Test health check polling with retries before success.""" + resource = LoadBalancerSlsResource( + name="test", + imageName="image", + id="test-endpoint-id", + ) + + with patch.object(resource, "_check_ping_endpoint") as mock_check: + # Fail twice, then succeed + mock_check.side_effect = [False, False, True] + + result = await resource._wait_for_health(max_retries=5, retry_interval=0) + + assert result is True + assert mock_check.call_count == 3 + + @pytest.mark.asyncio + async def test_wait_for_health_timeout(self): + """Test health check polling timeout after max retries.""" + resource = LoadBalancerSlsResource( + name="test", + imageName="image", + id="test-endpoint-id", + ) + + with patch.object(resource, "_check_ping_endpoint") as mock_check: + mock_check.return_value = False + + result = await resource._wait_for_health(max_retries=3, retry_interval=0) + + assert result is False + assert mock_check.call_count == 3 + + @pytest.mark.asyncio + async def test_wait_for_health_no_id(self): + """Test health check when endpoint ID not set.""" + resource = LoadBalancerSlsResource( + name="test", + imageName="image", + # id not set + ) + + with pytest.raises(ValueError, match="Cannot wait for health"): + await resource._wait_for_health() + + @pytest.mark.asyncio + async def test_is_deployed_async_with_id(self): + """Test is_deployed_async returns True when healthy.""" + resource = LoadBalancerSlsResource( + name="test", + imageName="image", + id="test-endpoint-id", + ) + + with patch.object(resource, "_check_ping_endpoint") as mock_check: + mock_check.return_value = True + + result = await resource.is_deployed_async() + + assert result is True + + @pytest.mark.asyncio + async def test_is_deployed_async_without_id(self): + """Test is_deployed_async returns False when ID not set.""" + resource = LoadBalancerSlsResource( + name="test", + imageName="image", + ) + + result = await resource.is_deployed_async() + + assert result is False + + @pytest.mark.asyncio + async def test_is_deployed_async_unhealthy(self): + """Test is_deployed_async returns False when unhealthy.""" + resource = LoadBalancerSlsResource( + name="test", + imageName="image", + id="test-endpoint-id", + ) + + with patch.object(resource, "_check_ping_endpoint") as mock_check: + mock_check.return_value = False + + result = await resource.is_deployed_async() + + assert result is False + + +class TestLoadBalancerSlsResourceDeployment: + """Test deployment flow.""" + + @pytest.mark.asyncio + async def test_do_deploy_validates_configuration(self): + """Test that _do_deploy validates LB configuration.""" + resource = LoadBalancerSlsResource( + name="test", + imageName="image", + scalerType=ServerlessScalerType.QUEUE_DELAY, + ) + + with pytest.raises(ValueError, match="requires REQUEST_COUNT scaler"): + await resource._do_deploy() + + @pytest.mark.asyncio + async def test_do_deploy_already_deployed(self): + """Test _do_deploy skips deployment if already deployed.""" + resource = LoadBalancerSlsResource( + name="test", + imageName="image", + id="existing-id", + ) + + with patch.object(LoadBalancerSlsResource, "is_deployed") as mock_deployed: + mock_deployed.return_value = True + + result = await resource._do_deploy() + + assert result == resource + + @pytest.mark.asyncio + async def test_do_deploy_success(self): + """Test successful deployment with health 
check.""" + resource = LoadBalancerSlsResource( + name="test", + imageName="image", + ) + + mock_deployed = LoadBalancerSlsResource( + name="test", + imageName="image", + id="new-endpoint-id", + ) + + with ( + patch.object( + LoadBalancerSlsResource, "is_deployed", MagicMock(return_value=False) + ), + patch.object( + resource, "_wait_for_health", new_callable=AsyncMock, return_value=True + ) as mock_wait, + patch.object( + ServerlessResource, + "_do_deploy", + new_callable=AsyncMock, + return_value=mock_deployed, + ), + ): + result = await resource._do_deploy() + + assert result == mock_deployed + mock_wait.assert_called_once() + + @pytest.mark.asyncio + async def test_do_deploy_health_check_timeout(self): + """Test deployment fails if health check times out.""" + resource = LoadBalancerSlsResource( + name="test", + imageName="image", + ) + + mock_deployed = LoadBalancerSlsResource( + name="test", + imageName="image", + id="new-endpoint-id", + ) + + with ( + patch.object( + LoadBalancerSlsResource, "is_deployed", MagicMock(return_value=False) + ), + patch.object( + resource, "_wait_for_health", new_callable=AsyncMock, return_value=False + ), + patch.object( + ServerlessResource, + "_do_deploy", + new_callable=AsyncMock, + return_value=mock_deployed, + ), + ): + with pytest.raises(TimeoutError, match="failed to become healthy"): + await resource._do_deploy() + + @pytest.mark.asyncio + async def test_do_deploy_parent_deploy_failure(self): + """Test deployment handles parent deploy failure.""" + resource = LoadBalancerSlsResource( + name="test", + imageName="image", + ) + + with ( + patch.object( + LoadBalancerSlsResource, "is_deployed", MagicMock(return_value=False) + ), + patch.object( + ServerlessResource, + "_do_deploy", + new_callable=AsyncMock, + side_effect=ValueError("RunPod API error"), + ), + ): + with pytest.raises(ValueError, match="RunPod API error"): + await resource._do_deploy() + + +class TestLoadBalancerSlsResourceIntegration: + """Integration tests with ResourceManager.""" + + def test_resource_manager_integration(self): + """Test that LoadBalancerSlsResource can be created and used.""" + # Test that LoadBalancerSlsResource can be instantiated and used + resource = LoadBalancerSlsResource( + name="integration-test", + imageName="test-image:latest", + ) + + assert isinstance(resource, LoadBalancerSlsResource) + assert resource.type == ServerlessType.LB + + def test_is_deployed_sync(self): + """Test synchronous is_deployed method.""" + resource = LoadBalancerSlsResource( + name="test", + imageName="image", + id="test-id", + ) + + # Mock the endpoint property and its health method + mock_endpoint = MagicMock() + mock_endpoint.health.return_value = {"status": "healthy"} + + with patch.object( + LoadBalancerSlsResource, + "endpoint", + new_callable=lambda: property(lambda self: mock_endpoint), + ): + result = resource.is_deployed() + + assert result is True + mock_endpoint.health.assert_called_once() + + def test_is_deployed_sync_no_id(self): + """Test is_deployed returns False when no ID.""" + resource = LoadBalancerSlsResource( + name="test", + imageName="image", + ) + + result = resource.is_deployed() + + assert result is False diff --git a/tests/unit/test_load_balancer_sls_stub.py b/tests/unit/test_load_balancer_sls_stub.py new file mode 100644 index 00000000..c5adcbf6 --- /dev/null +++ b/tests/unit/test_load_balancer_sls_stub.py @@ -0,0 +1,436 @@ +"""Unit tests for LoadBalancerSlsStub functionality.""" + +import base64 +import pytest +from unittest.mock import AsyncMock, 
MagicMock, patch + +import cloudpickle + +from tetra_rp import LoadBalancerSlsResource +from tetra_rp.stubs.load_balancer_sls import LoadBalancerSlsStub + + +# Create test resources +test_lb_resource = LoadBalancerSlsResource( + name="test-lb", + imageName="test:latest", +) + + +class TestLoadBalancerSlsStubPrepareRequest: + """Test suite for _prepare_request method.""" + + def test_prepare_request_with_no_args(self): + """Test request preparation with no arguments.""" + stub = LoadBalancerSlsStub(test_lb_resource) + + def test_func(): + return "result" + + request = stub._prepare_request(test_func, None, None, True) + + assert request["function_name"] == "test_func" + assert "def test_func" in request["function_code"] + assert request["dependencies"] == [] + assert request["system_dependencies"] == [] + assert request["accelerate_downloads"] is True + assert "args" not in request or request["args"] == [] + assert "kwargs" not in request or request["kwargs"] == {} + + def test_prepare_request_with_args(self): + """Test request preparation with positional arguments.""" + stub = LoadBalancerSlsStub(test_lb_resource) + + def add(x, y): + return x + y + + arg1 = 5 + arg2 = 3 + request = stub._prepare_request(add, None, None, True, arg1, arg2) + + assert request["function_name"] == "add" + assert len(request["args"]) == 2 + + # Verify args are properly serialized + decoded_arg1 = cloudpickle.loads(base64.b64decode(request["args"][0])) + decoded_arg2 = cloudpickle.loads(base64.b64decode(request["args"][1])) + assert decoded_arg1 == 5 + assert decoded_arg2 == 3 + + def test_prepare_request_with_kwargs(self): + """Test request preparation with keyword arguments.""" + stub = LoadBalancerSlsStub(test_lb_resource) + + def greet(name, greeting="Hello"): + return f"{greeting}, {name}!" 
+ + request = stub._prepare_request( + greet, None, None, True, name="Alice", greeting="Hi" + ) + + assert "kwargs" in request + assert len(request["kwargs"]) == 2 + + # Verify kwargs are properly serialized + decoded_name = cloudpickle.loads(base64.b64decode(request["kwargs"]["name"])) + decoded_greeting = cloudpickle.loads( + base64.b64decode(request["kwargs"]["greeting"]) + ) + assert decoded_name == "Alice" + assert decoded_greeting == "Hi" + + def test_prepare_request_with_dependencies(self): + """Test request preparation includes dependencies.""" + stub = LoadBalancerSlsStub(test_lb_resource) + + def test_func(): + return "result" + + dependencies = ["requests", "numpy"] + system_deps = ["git"] + + request = stub._prepare_request(test_func, dependencies, system_deps, True) + + assert request["dependencies"] == dependencies + assert request["system_dependencies"] == system_deps + + +class TestLoadBalancerSlsStubHandleResponse: + """Test suite for _handle_response method.""" + + def test_handle_response_success(self): + """Test successful response handling.""" + stub = LoadBalancerSlsStub(test_lb_resource) + + result_value = {"status": "ok", "value": 42} + result_b64 = base64.b64encode(cloudpickle.dumps(result_value)).decode("utf-8") + + response = {"success": True, "result": result_b64} + + result = stub._handle_response(response) + + assert result == result_value + + def test_handle_response_error(self): + """Test error response handling.""" + stub = LoadBalancerSlsStub(test_lb_resource) + + response = {"success": False, "error": "Function execution failed"} + + with pytest.raises(Exception, match="Remote execution failed"): + stub._handle_response(response) + + def test_handle_response_invalid_type(self): + """Test handling of invalid response type.""" + stub = LoadBalancerSlsStub(test_lb_resource) + + with pytest.raises(ValueError, match="Invalid response type"): + stub._handle_response("not a dict") + + def test_handle_response_missing_result(self): + """Test handling of success response without result.""" + stub = LoadBalancerSlsStub(test_lb_resource) + + response = {"success": True, "result": None} + + with pytest.raises( + ValueError, match="Response marked success but result is None" + ): + stub._handle_response(response) + + def test_handle_response_invalid_base64(self): + """Test handling of invalid base64 in result.""" + stub = LoadBalancerSlsStub(test_lb_resource) + + response = {"success": True, "result": "not_valid_base64!!!"} + + with pytest.raises(ValueError, match="Failed to deserialize result"): + stub._handle_response(response) + + +class TestLoadBalancerSlsStubExecuteFunction: + """Test suite for _execute_function method.""" + + @pytest.mark.asyncio + async def test_execute_function_no_endpoint_url(self): + """Test error when endpoint_url is not available.""" + mock_resource = MagicMock() + mock_resource.endpoint_url = None + stub = LoadBalancerSlsStub(mock_resource) + + request = { + "function_name": "test_func", + "function_code": "def test_func(): pass", + } + + with pytest.raises(ValueError, match="Endpoint URL not available"): + await stub._execute_function(request) + + @pytest.mark.asyncio + async def test_execute_function_timeout(self): + """Test timeout error handling.""" + mock_resource = MagicMock() + mock_resource.endpoint_url = "http://localhost:8000" + stub = LoadBalancerSlsStub(mock_resource) + + request = { + "function_name": "test_func", + "function_code": "def test_func(): pass", + } + + import httpx + + with 
patch("tetra_rp.stubs.load_balancer_sls.httpx.AsyncClient") as mock_client: + mock_client.return_value.__aenter__.return_value.post = AsyncMock( + side_effect=httpx.TimeoutException("Timeout") + ) + + with pytest.raises(TimeoutError, match="Execution timeout"): + await stub._execute_function(request) + + @pytest.mark.asyncio + async def test_execute_function_http_error(self): + """Test HTTP error handling.""" + mock_resource = MagicMock() + mock_resource.endpoint_url = "http://localhost:8000" + mock_resource.name = "test-lb" + stub = LoadBalancerSlsStub(mock_resource) + + request = { + "function_name": "test_func", + "function_code": "def test_func(): pass", + } + + import httpx + + mock_response = MagicMock() + mock_response.status_code = 500 + mock_response.text = "Internal server error" + + with patch("tetra_rp.stubs.load_balancer_sls.httpx.AsyncClient") as mock_client: + error = httpx.HTTPStatusError( + "Error", request=MagicMock(), response=mock_response + ) + mock_client.return_value.__aenter__.return_value.post = AsyncMock( + side_effect=error + ) + + with pytest.raises(RuntimeError, match="HTTP error from endpoint"): + await stub._execute_function(request) + + +class TestLoadBalancerSlsStubCall: + """Test suite for __call__ method.""" + + @pytest.mark.asyncio + async def test_call_success(self): + """Test successful stub execution.""" + mock_resource = MagicMock() + stub = LoadBalancerSlsStub(mock_resource) + + def add(x, y): + return x + y + + with patch.object(stub, "_execute_function") as mock_execute: + result_b64 = base64.b64encode(cloudpickle.dumps(8)).decode("utf-8") + mock_execute.return_value = {"success": True, "result": result_b64} + + result = await stub(add, None, None, True, 5, 3) + + assert result == 8 + mock_execute.assert_called_once() + + @pytest.mark.asyncio + async def test_call_with_dependencies(self): + """Test stub execution with dependencies.""" + mock_resource = MagicMock() + stub = LoadBalancerSlsStub(mock_resource) + + def use_requests(): + return "success" + + deps = ["requests"] + + with patch.object(stub, "_execute_function") as mock_execute: + result_b64 = base64.b64encode(cloudpickle.dumps("success")).decode("utf-8") + mock_execute.return_value = {"success": True, "result": result_b64} + + result = await stub(use_requests, deps, None, True) + + assert result == "success" + # Verify dependencies were included in request + call_args = mock_execute.call_args + request = call_args[0][0] + assert request["dependencies"] == deps + + +class TestLoadBalancerSlsStubRouting: + """Test suite for routing detection between /execute and user routes.""" + + def test_should_use_execute_for_live_load_balancer(self): + """Test that LiveLoadBalancer always uses /execute endpoint.""" + from tetra_rp import LiveLoadBalancer + from tetra_rp import remote + + lb = LiveLoadBalancer(name="test-live") + stub = LoadBalancerSlsStub(lb) + + @remote(lb, method="POST", path="/api/test") + def test_func(): + pass + + assert stub._should_use_execute_endpoint(test_func) is True + + def test_should_use_user_route_for_deployed_lb(self): + """Test that deployed LoadBalancerSlsResource uses user-defined route.""" + from tetra_rp import remote + + lb = LoadBalancerSlsResource(name="test-deployed", imageName="test:latest") + stub = LoadBalancerSlsStub(lb) + + @remote(lb, method="POST", path="/api/test") + def test_func(): + pass + + assert stub._should_use_execute_endpoint(test_func) is False + + def test_should_fallback_to_execute_without_routing_metadata(self): + """Test fallback to 
/execute when routing metadata is missing.""" + lb = LoadBalancerSlsResource(name="test", imageName="test:latest") + stub = LoadBalancerSlsStub(lb) + + def func_without_metadata(): + pass + + assert stub._should_use_execute_endpoint(func_without_metadata) is True + + def test_should_fallback_to_execute_with_incomplete_metadata(self): + """Test fallback to /execute when routing metadata is incomplete.""" + lb = LoadBalancerSlsResource(name="test", imageName="test:latest") + stub = LoadBalancerSlsStub(lb) + + def func_with_incomplete_metadata(): + pass + + # Attach incomplete metadata + func_with_incomplete_metadata.__remote_config__ = {"method": "POST"} + + assert stub._should_use_execute_endpoint(func_with_incomplete_metadata) is True + + @pytest.mark.asyncio + async def test_execute_via_user_route_success(self): + """Test successful execution via user-defined route.""" + mock_resource = MagicMock() + mock_resource.endpoint_url = "http://localhost:8000" + mock_resource.name = "test-lb" + stub = LoadBalancerSlsStub(mock_resource) + + def add(x, y): + return x + y + + mock_response = MagicMock() + mock_response.json.return_value = {"result": 8} + + with patch("tetra_rp.stubs.load_balancer_sls.httpx.AsyncClient") as mock_client: + mock_client.return_value.__aenter__.return_value.request = AsyncMock( + return_value=mock_response + ) + + result = await stub._execute_via_user_route(add, "POST", "/api/add", 5, 3) + + assert result == {"result": 8} + # Verify correct HTTP method and URL + mock_client.return_value.__aenter__.return_value.request.assert_called_once() + call_args = ( + mock_client.return_value.__aenter__.return_value.request.call_args + ) + assert call_args[0][0] == "POST" + assert call_args[0][1] == "http://localhost:8000/api/add" + # Verify correct JSON body with mapped parameters + assert call_args[1]["json"] == {"x": 5, "y": 3} + + @pytest.mark.asyncio + async def test_execute_via_user_route_with_kwargs(self): + """Test user route execution with keyword arguments.""" + mock_resource = MagicMock() + mock_resource.endpoint_url = "http://localhost:8000" + mock_resource.name = "test-lb" + stub = LoadBalancerSlsStub(mock_resource) + + def greet(name, greeting="Hello"): + return f"{greeting}, {name}!" + + mock_response = MagicMock() + mock_response.json.return_value = "Hi, Alice!" + + with patch("tetra_rp.stubs.load_balancer_sls.httpx.AsyncClient") as mock_client: + mock_client.return_value.__aenter__.return_value.request = AsyncMock( + return_value=mock_response + ) + + result = await stub._execute_via_user_route( + greet, "POST", "/api/greet", "Alice", greeting="Hi" + ) + + assert result == "Hi, Alice!" 
+
+        # Verify JSON body has both positional arg and kwargs
+        call_args = (
+            mock_client.return_value.__aenter__.return_value.request.call_args
+        )
+        assert call_args[1]["json"] == {"name": "Alice", "greeting": "Hi"}
+
+    @pytest.mark.asyncio
+    async def test_call_routes_to_user_path_for_deployed_endpoint(self):
+        """Test that __call__ routes to user path for deployed endpoints."""
+        mock_resource = MagicMock()
+        mock_resource.endpoint_url = "http://localhost:8000"
+        mock_resource.name = "test-lb"
+        stub = LoadBalancerSlsStub(mock_resource)
+
+        def test_func(x):
+            return x
+
+        test_func.__remote_config__ = {
+            "method": "POST",
+            "path": "/api/test",
+            "resource_config": mock_resource,
+        }
+
+        with patch.object(stub, "_should_use_execute_endpoint") as mock_detect:
+            with patch.object(stub, "_execute_via_user_route") as mock_user_route:
+                mock_detect.return_value = False
+                mock_user_route.return_value = {"result": 42}
+
+                result = await stub(test_func, None, None, True, 42)
+
+                # Should route to _execute_via_user_route, not _execute_function
+                mock_user_route.assert_called_once()
+                assert result == {"result": 42}
+
+    @pytest.mark.asyncio
+    async def test_call_routes_to_execute_for_live_endpoint(self):
+        """Test that __call__ routes to /execute for LiveLoadBalancer."""
+        mock_resource = MagicMock()
+        stub = LoadBalancerSlsStub(mock_resource)
+
+        def test_func():
+            pass
+
+        with patch.object(stub, "_should_use_execute_endpoint") as mock_detect:
+            with patch.object(stub, "_execute_function") as mock_execute:
+                with patch.object(stub, "_handle_response") as mock_handle:
+                    mock_detect.return_value = True
+                    mock_execute.return_value = {"success": True, "result": "test"}
+                    mock_handle.return_value = "handled"
+
+                    result = await stub(test_func, None, None, True)
+
+                    # Should route to _execute_function, not _execute_via_user_route
+                    mock_execute.assert_called_once()
+                    assert result == "handled"