diff --git a/docs/Cross_Endpoint_Routing.md b/docs/Cross_Endpoint_Routing.md
new file mode 100644
index 00000000..6d059636
--- /dev/null
+++ b/docs/Cross_Endpoint_Routing.md
@@ -0,0 +1,943 @@
+# Cross-Endpoint Routing for Serverless Functions
+
+## Overview
+
+Cross-endpoint routing enables serverless functions to seamlessly call functions deployed on different endpoints. Functions can execute locally or remotely based on service discovery configuration, allowing developers to build distributed applications without manual routing logic.
+
+## Problem Statement
+
+Previously, serverless functions were isolated to their deployment endpoint. Building distributed applications required:
+- Manual HTTP calls to other endpoints with serialization boilerplate
+- No unified function invocation pattern across endpoints
+- Difficulty discovering which functions exist on which endpoints
+- Complex error handling for remote failures
+- No automatic argument serialization/deserialization
+
+Cross-endpoint routing solves these problems by providing transparent function routing with manifest-based service discovery.
+
+## User Guide
+
+### Quick Start
+
+Cross-endpoint routing works through a manifest that maps functions to resource configurations. Functions are routed automatically based on this manifest.
+
+#### 1. Create a Manifest
+
+Create a `flash_manifest.json` in your project root:
+
+```json
+{
+ "version": "1.0",
+ "project_name": "my-project",
+ "function_registry": {
+ "process_image": "image-processor",
+ "generate_report": "report-generator"
+ },
+ "resources": {
+ "image-processor": {
+ "resource_type": "ServerlessEndpoint",
+ "handler_file": "handlers/image.py",
+ "functions": [
+ {"name": "process_image", "module": "processors.image", "is_async": true}
+ ]
+ },
+ "report-generator": {
+ "resource_type": "ServerlessEndpoint",
+ "handler_file": "handlers/report.py",
+ "functions": [
+ {"name": "generate_report", "module": "processors.report", "is_async": true}
+ ]
+ }
+ }
+}
+```
+
+The manifest structure (a loading sketch follows this list):
+- **`function_registry`**: Maps function names to resource configuration names
+- **`resources`**: Defines resource configurations with their handler files and function details
+- **Functions not in `function_registry`** execute locally on the current endpoint
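+
+For programmatic access, this change also ships typed manifest models. A minimal loading sketch using `Manifest.from_dict` from `tetra_rp.runtime.models` (the file path is an example):
+
+```python
+import json
+from pathlib import Path
+
+from tetra_rp.runtime.models import Manifest
+
+# Parse the JSON manifest into typed dataclasses
+# (Manifest / ResourceConfig / FunctionMetadata)
+raw = json.loads(Path("flash_manifest.json").read_text())
+manifest = Manifest.from_dict(raw)
+
+print(manifest.function_registry)  # {"process_image": "image-processor", ...}
+print(manifest.resources["image-processor"].handler_file)  # "handlers/image.py"
+```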
+
+#### 2. Set Environment Variables
+
+Configure the mothership directory URL (required for remote routing):
+
+```bash
+# Required for cross-endpoint routing to work
+export FLASH_MOTHERSHIP_URL=https://mothership.example.com
+
+# Optional: Identifies the current endpoint (useful for distributed tracing)
+export RUNPOD_ENDPOINT_ID=gpu_config
+```
+
+Note: Without `FLASH_MOTHERSHIP_URL`, the system gracefully falls back and all functions execute locally.
+
+#### 3. Define Functions
+
+Define functions normally. The routing system decides execution location:
+
+```python
+from tetra_rp import stub
+
+@stub.function()
+async def process_image(image_path: str) -> dict:
+ """Process an image - may execute locally or remotely."""
+ # This function might route to 'image-processor' endpoint
+ # based on manifest configuration
+ return {"processed": True}
+
+@stub.function()
+async def local_only_function(data: str) -> str:
+ """Always executes locally (not in manifest)."""
+ return f"Processed: {data}"
+
+@stub.function()
+async def generate_report(data: list) -> bytes:
+ """May route to 'report-generator' endpoint."""
+ return b"report data"
+```
+
+#### 4. Call Functions Normally
+
+The routing system handles execution location transparently:
+
+```python
+# Local execution (not in manifest)
+result = await local_only_function("hello")
+
+# Remote or local execution (based on manifest)
+result = await process_image("path/to/image.jpg")
+```
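+
+Since routed functions are coroutines, a complete entry point might look like this sketch:
+
+```python
+import asyncio
+
+async def main():
+    # Local call (not in manifest)
+    print(await local_only_function("hello"))
+    # Possibly remote call (routed via manifest)
+    print(await process_image("path/to/image.jpg"))
+
+asyncio.run(main())
+```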
+
+### Configuration
+
+#### Manifest File
+
+The manifest file (`flash_manifest.json`) defines function routing and resource configurations:
+
+```json
+{
+ "version": "1.0",
+ "project_name": "my-project",
+ "function_registry": {
+ "function_name": "resource_config_name",
+ "another_function": "another_resource"
+ },
+ "resources": {
+ "resource_config_name": {
+ "resource_type": "ServerlessEndpoint",
+ "handler_file": "handlers/module.py",
+ "functions": [
+ {"name": "function_name", "module": "path.to.module", "is_async": true}
+ ]
+ }
+ }
+}
+```
+
+**Key Points**:
+- **`function_registry`**: Maps function names to resource configuration names
+- **`resources`**: Defines deployment configuration for each resource
+- **Functions not listed** in `function_registry` execute locally
+- Missing manifest: logs warning, all functions execute locally
+
+**Manifest Location** (in order of precedence; a lookup sketch follows):
+1. Path specified in `FLASH_MANIFEST_PATH` env var
+2. `flash_manifest.json` in project root
+3. Auto-detected in package directory
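+
+A sketch of this lookup order (the helper is illustrative, not the actual implementation):
+
+```python
+import os
+from pathlib import Path
+from typing import Optional
+
+def resolve_manifest_path() -> Optional[Path]:
+    """Illustrative resolution following the precedence above."""
+    # 1. Explicit override via environment variable
+    env_path = os.getenv("FLASH_MANIFEST_PATH")
+    if env_path and Path(env_path).is_file():
+        return Path(env_path)
+    # 2. flash_manifest.json in the project root
+    candidate = Path.cwd() / "flash_manifest.json"
+    if candidate.is_file():
+        return candidate
+    # 3. Package-directory auto-detection is implementation-specific;
+    #    returning None means all functions execute locally
+    return None
+```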
+
+#### Environment Variables
+
+| Variable | Required | Purpose |
+|----------|----------|---------|
+| `FLASH_MOTHERSHIP_URL` | Yes* | URL of mothership directory service |
+| `RUNPOD_ENDPOINT_ID` | No | Current endpoint ID (for tracing) |
+| `FLASH_MANIFEST_PATH` | No | Explicit path to manifest file |
+
+*Required for remote routing; without it, all functions execute locally
+
+### Usage Patterns
+
+#### Pattern 1: Microservice Architecture
+
+Split functionality across endpoints using manifest:
+
+**Manifest (flash_manifest.json)**:
+```json
+{
+ "version": "1.0",
+ "function_registry": {
+ "resize_image": "image-processor",
+ "generate_metrics": "report-generator",
+ "workflow": "orchestrator"
+ },
+ "resources": {
+ "image-processor": {
+ "resource_type": "ServerlessEndpoint",
+ "functions": [
+ {"name": "resize_image", "module": "image_module", "is_async": true}
+ ]
+ },
+ "report-generator": {
+ "resource_type": "ServerlessEndpoint",
+ "functions": [
+ {"name": "generate_metrics", "module": "metrics_module", "is_async": true}
+ ]
+ },
+ "orchestrator": {
+ "resource_type": "ServerlessEndpoint",
+ "functions": [
+ {"name": "workflow", "module": "orchestration", "is_async": true}
+ ]
+ }
+ }
+}
+```
+
+**Functions**:
+```python
+@stub.function()
+async def resize_image(path: str, size: int) -> str:
+ return process_image(path, size)
+
+@stub.function()
+async def generate_metrics(data: list) -> dict:
+ return create_metrics(data)
+
+@stub.function()
+async def workflow():
+ # Transparently calls across endpoints
+ image = await resize_image("input.jpg", 512)
+ metrics = await generate_metrics([1, 2, 3])
+ return {"image": image, "metrics": metrics}
+```
+
+#### Pattern 2: Mixed Local and Remote
+
+Configure some functions for remote execution, others local:
+
+**Manifest**:
+```json
+{
+ "version": "1.0",
+ "function_registry": {
+ "heavy_computation": "gpu-cluster"
+ },
+ "resources": {
+ "gpu-cluster": {
+ "resource_type": "ServerlessEndpoint",
+ "functions": [
+ {"name": "heavy_computation", "module": "gpu_module", "is_async": true}
+ ]
+ }
+ }
+}
+```
+
+**Functions**:
+```python
+@stub.function()
+async def heavy_computation(data: bytes) -> bytes:
+ # Routes to GPU cluster (in function_registry)
+ return gpu_process(data)
+
+@stub.function()
+async def light_computation(value: int) -> int:
+ # Always local - not in function_registry
+ return value * 2
+```
+
+#### Pattern 3: Fallback to Local
+
+Functions gracefully fall back to local execution if routing fails:
+
+```python
+@stub.function()
+async def critical_service(request: dict) -> dict:
+ # Routes to critical-endpoint if:
+ # - In function_registry
+ # - Directory available
+ # Otherwise executes locally
+ return handle_critical(request)
+
+@stub.function()
+async def helper_function(x: int) -> int:
+ # Always local - not in manifest
+ return x + 1
+```
+
+### Error Handling
+
+#### Common Issues
+
+**Directory Unavailable**
+
+If `FLASH_MOTHERSHIP_URL` is not set or unreachable:
+```
+WARNING: FLASH_MOTHERSHIP_URL not set, directory unavailable
+```
+
+Functions default to local execution. Set the environment variable to enable routing.
+
+**Manifest Not Found**
+
+If manifest file cannot be found:
+```
+WARNING: flash_manifest.json not found. Cross-endpoint routing disabled.
+```
+
+The system logs a warning but continues; all functions execute locally. Create `flash_manifest.json` in your project root or set `FLASH_MANIFEST_PATH` to enable routing.
+
+**Serialization Error**
+
+If function arguments cannot be serialized:
+```
+SerializationError: Failed to serialize arguments: [error details]
+```
+
+Ensure arguments are cloudpickle-serializable. Objects that hold OS-level state (open file handles, sockets, locks, database connections) cannot be serialized.
+
+**Remote Execution Failed**
+
+If remote endpoint fails:
+```
+RemoteExecutionError: Remote function execution failed: [error details]
+```
+
+Check remote endpoint logs and ensure it's deployed and healthy.
+
+### Serialization
+
+Arguments are serialized using cloudpickle and transmitted as base64-encoded HTTP payloads:
+
+```python
+# Simple types - always work
+await process_image("path.jpg") # str
+await generate_report([1, 2, 3]) # list
+await compute(data=np.array([1, 2, 3])) # numpy array
+
+# Complex types - work if picklable
+class DataClass:
+ def __init__(self, value):
+ self.value = value
+
+await process_data(DataClass(42)) # Custom class
+
+# Won't work - objects holding OS-level state
+async def with_open_file():
+    f = open("data.bin", "rb")  # File handles can't be pickled
+    await remote_function(f)  # Will fail
+
+async def with_lock():
+    import threading
+    lock = threading.Lock()  # Synchronization primitives can't be pickled
+    await remote_function(lock)  # Will fail
+```
+
+Unlike plain pickle, cloudpickle can serialize lambdas and locally defined functions; they still require a compatible Python and cloudpickle version on the remote endpoint to deserialize.
+
+## Contributor Guide
+
+### Architecture Overview
+
+```mermaid
+graph TD
+ A["Function Call"] -->|"intercepts stub layer"| B["ProductionWrapper"]
+
+ B -->|"load service configuration"| C["ServiceRegistry"]
+ C -->|"if not cached"| D["DirectoryClient"]
+ D -->|"query mothership API"| E["Directory
Endpoint URLs"]
+ E -->|"cache result
TTL 300s"| C
+
+ C -->|"lookup in manifest
flash_manifest.json"| F{"Routing
Decision"}
+
+ F -->|"Local
resource=None"| G["Execute Locally
original stub"]
+ F -->|"Not in Manifest"| G
+ F -->|"Remote
resource found"| H["Remote Execution"]
+
+ H --> I["Serialize Arguments
cloudpickle → base64"]
+ I --> J["HTTP POST
to Remote Endpoint"]
+ J --> K["Remote Function
Executes"]
+ K --> L["Return Response
base64 → cloudpickle"]
+ L --> M["Deserialized Result"]
+
+ N["Error Handling:
- RemoteExecutionError
- SerializationError
- DirectoryUnavailableError"] -.-> H
+ N -.-> I
+ N -.-> J
+
+ G --> O["Result Returned
to Caller"]
+ M --> O
+
+ style A fill:#1976d2,stroke:#0d47a1,stroke-width:3px,color:#fff
+ style B fill:#0d7f1f,stroke:#0d4f1f,stroke-width:3px,color:#fff
+ style C fill:#0d7f1f,stroke:#0d4f1f,stroke-width:3px,color:#fff
+ style D fill:#0d7f1f,stroke:#0d4f1f,stroke-width:3px,color:#fff
+ style F fill:#d32f2f,stroke:#b71c1c,stroke-width:3px,color:#fff
+ style G fill:#1976d2,stroke:#0d47a1,stroke-width:3px,color:#fff
+ style H fill:#1976d2,stroke:#0d47a1,stroke-width:3px,color:#fff
+ style I fill:#1565c0,stroke:#0d47a1,stroke-width:2px,color:#fff
+ style J fill:#1565c0,stroke:#0d47a1,stroke-width:2px,color:#fff
+ style K fill:#0d7f1f,stroke:#0d4f1f,stroke-width:3px,color:#fff
+ style L fill:#1565c0,stroke:#0d47a1,stroke-width:2px,color:#fff
+ style M fill:#1976d2,stroke:#0d47a1,stroke-width:3px,color:#fff
+ style N fill:#f57c00,stroke:#e65100,stroke-width:3px,color:#fff
+ style O fill:#1976d2,stroke:#0d47a1,stroke-width:3px,color:#fff
+ style E fill:#0d7f1f,stroke:#0d4f1f,stroke-width:3px,color:#fff
+```
+
+### Core Components
+
+#### 1. ProductionWrapper
+
+**Location**: `src/tetra_rp/runtime/production_wrapper.py`
+
+Intercepts function calls at the stub layer and routes to local or remote execution:
+
+```python
+class ProductionWrapper:
+ """Wrapper that routes function execution between endpoints."""
+
+ async def wrap_function_execution(
+ self,
+ original_stub_func: Callable,
+ func: Callable,
+ dependencies: Optional[list],
+ system_dependencies: Optional[list],
+ accelerate_downloads: bool,
+ *args: Any,
+ **kwargs: Any,
+ ) -> Any:
+ """Route function execution to local or remote endpoint."""
+ # 1. Load directory (if needed)
+ await self.service_registry._ensure_directory_loaded()
+
+ # 2. Look up function in manifest
+ resource = self.service_registry.get_resource_for_function(func.__name__)
+
+ # 3. Execute locally if not in manifest or resource is None
+ if resource is None:
+ return await original_stub_func(
+ func,
+ dependencies,
+ system_dependencies,
+ accelerate_downloads,
+ *args,
+ **kwargs
+ )
+
+ # 4. Execute remotely
+ return await self._execute_remote(resource, func.__name__, args, kwargs)
+```
+
+**Routing Logic**:
+- **Local**: Function not in manifest or resource=None
+- **Remote**: Function in manifest with resource configuration
+
+**Argument Serialization** (shared helpers in `src/tetra_rp/runtime/serialization.py`):
+```python
+from tetra_rp.runtime.serialization import serialize_args, deserialize_args
+
+# Serialize each argument with cloudpickle + base64; raises SerializationError on failure
+serialized = serialize_args(args)
+
+# Deserialize back to the original values
+args = deserialize_args(serialized)
+```
+
+#### 2. ServiceRegistry
+
+**Location**: `src/tetra_rp/runtime/service_registry.py`
+
+Manages service discovery and manifest loading:
+
+```python
+class ServiceRegistry:
+ """Service discovery and routing for cross-endpoint function calls."""
+
+ def __init__(self, manifest_path: Optional[Path] = None):
+ """Initialize with manifest and optional directory client."""
+ self._load_manifest(manifest_path)
+ self._directory_client = ManifestClient(...)
+ self._directory = {} # Cached endpoint URLs
+ self._directory_lock = asyncio.Lock()
+
+ def get_resource_for_function(self, func_name: str) -> Optional[ServerlessResource]:
+ """Get resource config for function from manifest."""
+ # Returns None if:
+ # - Function not in manifest
+ # - Explicitly set to null in manifest
+
+ # Returns ServerlessResource if mapped in manifest
+ config_name = self._manifest["function_registry"].get(func_name)
+ return self._resolve_resource(config_name)
+
+ async def _ensure_directory_loaded(self) -> None:
+ """Load directory from mothership with caching (TTL 300s)."""
+ if self._is_directory_fresh():
+ return
+
+ async with self._directory_lock:
+ self._directory = await self._directory_client.get_directory()
+ self._directory_loaded_at = time.time()
+```
+
+**Manifest Format**:
+```json
+{
+ "version": "1.0",
+ "project_name": "project_name",
+ "function_registry": {
+ "function_name": "resource_config_name",
+ "local_function": null
+ },
+ "resources": {
+ "resource_config_name": {
+ "resource_type": "ServerlessEndpoint",
+ "handler_file": "handler.py",
+ "functions": [...]
+ }
+ }
+}
+```
+
+**Key fields**:
+- `function_registry`: Maps function names to resource config names (null = local)
+- `resources`: Defines resource configurations and their handler details
+
+**Directory Cache** (see the sketch below):
+- TTL: 300 seconds (configurable via `DEFAULT_CACHE_TTL`)
+- Thread-safe with `asyncio.Lock()`
+- Graceful fallback if directory unavailable
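+
+A sketch of the freshness check behind the `_ensure_directory_loaded()` snippet above (the `_cache_ttl` attribute holding the TTL is an assumption):
+
+```python
+import time
+
+def _is_directory_fresh(self) -> bool:
+    """True while the cached directory is younger than the TTL."""
+    if not self._directory:
+        return False
+    return (time.time() - self._directory_loaded_at) < self._cache_ttl
+```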
+
+#### 3. ManifestClient
+
+**Location**: `src/tetra_rp/runtime/manifest_client.py`
+
+HTTP client for the mothership manifest directory service:
+
+```python
+class ManifestClient:
+ """HTTP client for querying mothership manifest directory service.
+
+ The directory maps resource_config names to their endpoint URLs.
+ Example: {"gpu_config": "https://api.runpod.io/v2/abc123"}
+ """
+
+ async def get_directory(self) -> Dict[str, str]:
+ """Fetch endpoint directory from mothership.
+
+ Returns:
+ Dictionary mapping resource_config_name → endpoint_url.
+ Example: {"gpu_config": "https://api.runpod.io/v2/abc123"}
+
+ Raises:
+ ManifestServiceUnavailableError: If the manifest directory service is unavailable after retries.
+ """
+ # Queries {mothership_url}/directory endpoint with retry logic
+```
+
+**Configuration**:
+- Mothership URL from `FLASH_MOTHERSHIP_URL` env var
+- HTTP timeout: 10 seconds (via `DEFAULT_REQUEST_TIMEOUT`)
+- Retry logic: Exponential backoff with `DEFAULT_MAX_RETRIES` attempts (default: 3)
+- Uses `httpx` library for async HTTP requests
+- Raises `ImportError` if httpx not installed (with helpful message)
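+
+A minimal usage sketch, based on the async context manager support in `manifest_client.py` (the URL is a placeholder):
+
+```python
+from tetra_rp.runtime.manifest_client import ManifestClient
+
+async def print_directory() -> None:
+    # The context manager closes the underlying httpx client on exit
+    async with ManifestClient(mothership_url="https://mothership.example.com") as client:
+        directory = await client.get_directory()
+        for resource_name, endpoint_url in directory.items():
+            print(f"{resource_name} -> {endpoint_url}")
+```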
+
+#### 4. Exception Hierarchy
+
+**Location**: `src/tetra_rp/runtime/exceptions.py`
+
+Custom exceptions for cross-endpoint routing:
+
+```python
+class FlashRuntimeError(Exception):
+ """Base exception for runtime errors in cross-endpoint execution."""
+ pass
+
+class RemoteExecutionError(FlashRuntimeError):
+ """Raised when remote function execution fails."""
+ pass
+
+class SerializationError(FlashRuntimeError):
+ """Raised when serialization or deserialization of arguments fails."""
+ pass
+
+class ManifestError(FlashRuntimeError):
+ """Raised when manifest is invalid, missing, or has unexpected structure."""
+ pass
+
+class ManifestServiceUnavailableError(FlashRuntimeError):
+ """Raised when manifest directory service is unavailable."""
+ pass
+```
+
+**Usage**:
+```python
+try:
+ result = await remote_execution()
+except RemoteExecutionError as e:
+ logger.error(f"Remote execution failed: {e}")
+except SerializationError as e:
+ logger.error(f"Serialization failed: {e}")
+except ManifestError as e:
+ logger.error(f"Manifest configuration error: {e}")
+except ManifestServiceUnavailableError as e:
+ logger.warning(f"Manifest directory unavailable ({e}), falling back to local execution")
+```
+
+### Integration Points
+
+#### Stub Layer Integration
+
+ProductionWrapper integrates with the stub execution layer in `src/tetra_rp/stubs/registry.py`:
+
+```python
+# Before: Direct stub execution
+result = await original_stub_func(func, deps, ...)
+
+# After: Routed execution
+production_wrapper = ProductionWrapper(service_registry)
+result = await production_wrapper.wrap_function_execution(
+ original_stub_func,
+ func,
+ dependencies,
+ system_dependencies,
+ accelerate_downloads,
+ *args,
+ **kwargs
+)
+```
+
+The wrapper is transparent - function signatures and return types are unchanged.
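+
+A sketch of that wiring using the singleton accessor from `production_wrapper.py` (the surrounding stub plumbing is illustrative):
+
+```python
+from tetra_rp.runtime.production_wrapper import create_production_wrapper
+
+def make_routed_stub(original_stub_func):
+    """Wrap an existing stub entry point with cross-endpoint routing."""
+    async def routed_stub(func, dependencies, system_dependencies,
+                          accelerate_downloads, *args, **kwargs):
+        # create_production_wrapper() lazily builds a ServiceRegistry on first use
+        wrapper = create_production_wrapper()
+        return await wrapper.wrap_function_execution(
+            original_stub_func, func, dependencies,
+            system_dependencies, accelerate_downloads, *args, **kwargs,
+        )
+    return routed_stub
+```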
+
+#### Resource Manager
+
+Functions retrieve remote endpoint info from ResourceManager:
+
+```python
+# ServiceRegistry uses ResourceManager to find endpoint URLs
+resource_manager = ResourceManager()
+endpoint = resource_manager.get_resource_for_function("function_name")
+endpoint_url = endpoint.url # e.g., "https://api.runpod.io/v2/abc123"
+```
+
+### Configuration
+
+#### Runtime Config
+
+**Location**: `src/tetra_rp/runtime/config.py`
+
+Centralized configuration constants:
+
+```python
+# HTTP client configuration
+DEFAULT_REQUEST_TIMEOUT = 10  # seconds
+DEFAULT_MAX_RETRIES = 3
+DEFAULT_BACKOFF_BASE = 2
+
+# Directory cache configuration
+DEFAULT_CACHE_TTL = 300  # seconds
+
+# Serialization limits
+MAX_PAYLOAD_SIZE = 10 * 1024 * 1024  # 10MB
+```
+
+Add new configuration by:
+1. Defining constant in `config.py`
+2. Importing in component that uses it
+3. Making it configurable via an environment variable if appropriate (see the sketch below)
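+
+For example, a constant with an environment-variable override (the variable name `FLASH_DIRECTORY_CACHE_TTL` is hypothetical):
+
+```python
+# src/tetra_rp/runtime/config.py (sketch)
+import os
+
+# Hypothetical env override for the directory cache TTL
+DEFAULT_CACHE_TTL = int(os.getenv("FLASH_DIRECTORY_CACHE_TTL", "300"))  # seconds
+```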
+
+### Data Flow
+
+#### Local Execution Flow
+
+```
+Function Call
+ ↓
+ProductionWrapper.wrap_function_execution()
+ ↓
+ServiceRegistry.get_resource_for_function()
+ ↓
+Manifest Lookup (resource = None)
+ ↓
+Local Execution (original_stub_func)
+ ↓
+Result
+```
+
+#### Remote Execution Flow
+
+```
+Function Call
+ ↓
+ProductionWrapper.wrap_function_execution()
+ ↓
+ServiceRegistry.get_resource_for_function()
+ ↓
+Manifest Lookup (resource found)
+ ↓
+Ensure Directory Loaded
+ ↓
+ManifestClient.get_directory()
+ ↓
+Get Remote Endpoint URL
+ ↓
+Serialize Arguments (cloudpickle → base64)
+ ↓
+HTTP POST to Remote Endpoint
+ ↓
+Remote Function Execution
+ ↓
+Deserialize Result (base64 → cloudpickle)
+ ↓
+Result
+```
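+
+Concretely, the payload built in `ProductionWrapper._execute_remote` and POSTed to the remote endpoint has this shape (the base64 values are illustrative and truncated):
+
+```python
+payload = {
+    "input": {
+        "function_name": "process_image",
+        "execution_type": "function",         # or "class"
+        "args": ["gASVGgAAAA..."],            # base64(cloudpickle(arg))
+        "kwargs": {"size": "gASVBQAAAA..."},  # base64(cloudpickle(value))
+    }
+}
+```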
+
+### Design Decisions
+
+#### 1. Manifest-Based Routing
+
+**Decision**: Use JSON manifest for function-to-resource mapping
+
+**Rationale**:
+- Decouples function code from deployment configuration
+- Enables environment-specific routing without code changes
+- Version control friendly
+- Dynamic routing without recompilation
+
+#### 2. Thread-Safe Async Caching
+
+**Decision**: Use `asyncio.Lock()` for directory cache synchronization
+
+**Rationale**:
+- Prevents thundering herd on cache expiry
+- Efficient - only one coroutine loads directory
+- Simple to understand and maintain
+- Follows async/await patterns
+
+#### 3. Cloudpickle Serialization
+
+**Decision**: Use cloudpickle for argument serialization
+
+**Rationale**:
+- Supports complex Python objects
+- Handles closures and lambdas (when possible)
+- Standard in Python ML/data processing
+- Works with numpy arrays, dataframes, etc.
+
+#### 4. Graceful Fallback
+
+**Decision**: Default to local execution if directory unavailable
+
+**Rationale**:
+- Maintains application resilience
+- Doesn't fail if mothership unreachable
+- Allows local testing without directory
+- Gradual degradation vs catastrophic failure
+
+#### 5. Transparent Routing
+
+**Decision**: Routing happens at wrapper level, not in business logic
+
+**Rationale**:
+- Function code unchanged
+- Testing simplified - can mock ProductionWrapper
+- Easy to add/remove routing dynamically
+- No performance overhead if not needed
+
+### Extension Points
+
+#### Adding New Serialization Formats
+
+To support formats other than cloudpickle (a selection sketch follows these steps):
+
+1. Create serializer class:
+```python
+class JsonSerializer:
+ @staticmethod
+ def serialize(obj: Any) -> str:
+ return json.dumps(obj)
+
+ @staticmethod
+ def deserialize(data: str) -> Any:
+ return json.loads(data)
+```
+
+2. Update ProductionWrapper to select serializer based on config
+3. Add tests for new format
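+
+A sketch of step 2's format selection (the registry dict and `get_serializer` helper are assumptions, not existing APIs; `JsonSerializer` is the class from step 1):
+
+```python
+from typing import Any, Callable, Dict, Tuple
+
+from tetra_rp.runtime import serialization
+
+Serializer = Tuple[Callable[[Any], str], Callable[[str], Any]]
+
+# Hypothetical registry keyed by a configured format name
+SERIALIZERS: Dict[str, Serializer] = {
+    "cloudpickle": (serialization.serialize_arg, serialization.deserialize_arg),
+    "json": (JsonSerializer.serialize, JsonSerializer.deserialize),
+}
+
+def get_serializer(fmt: str = "cloudpickle") -> Serializer:
+    """Select the (serialize, deserialize) pair for a configured format."""
+    return SERIALIZERS[fmt]
+```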
+
+#### Adding New Directory Backends
+
+To support directories other than mothership:
+
+1. Create client class with `get_directory()` method:
+```python
+class CustomDirectoryClient:
+ async def get_directory(self) -> Dict[str, str]:
+ """Fetch directory mapping resource_config_name → endpoint_url."""
+ # Implementation specific to backend
+ return {"resource_name": "https://endpoint.url"}
+```
+
+2. Update ServiceRegistry to accept and use client in constructor:
+```python
+registry = ServiceRegistry(
+ manifest_path=Path("manifest.json"),
+ directory_client=CustomDirectoryClient(...)
+)
+```
+
+3. Update environment variable handling if needed (CustomDirectoryClient can read from env vars)
+
+#### Adding Routing Policies
+
+To implement routing logic beyond the manifest (an example policy follows these steps):
+
+1. Create policy class:
+```python
+class RoutingPolicy:
+ async def should_route(
+ self,
+ func_name: str,
+ resource: ServerlessResource,
+ args: tuple,
+ kwargs: dict
+ ) -> bool:
+ """Determine if function should route remotely."""
+ pass
+```
+
+2. Update ProductionWrapper to use policy
+3. Add configuration for policy selection
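+
+For instance, a hedged sketch of a payload-size policy (the default threshold mirrors `MAX_PAYLOAD_SIZE` from `config.py`; the wiring into ProductionWrapper is an assumption):
+
+```python
+import cloudpickle
+
+class PayloadSizePolicy:
+    """Route remotely only while the pickled arguments stay under a size cap."""
+
+    def __init__(self, max_bytes: int = 10 * 1024 * 1024):
+        self.max_bytes = max_bytes
+
+    async def should_route(self, func_name, resource, args, kwargs) -> bool:
+        # Oversized payloads fall back to local execution
+        payload = cloudpickle.dumps((args, kwargs))
+        return len(payload) <= self.max_bytes
+```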
+
+### Testing Strategy
+
+#### Unit Tests
+
+**ServiceRegistry Tests** (`tests/unit/runtime/test_service_registry.py`):
+- Manifest loading
+- Resource lookup
+- Directory caching
+- TTL expiry
+- Lock behavior under concurrency
+
+**ManifestClient Tests** (`tests/unit/runtime/test_manifest_client.py`):
+- Successful HTTP requests
+- Error handling
+- Retry logic
+- Timeout handling
+- URL validation
+
+**ProductionWrapper Tests** (`tests/unit/runtime/test_production_wrapper.py`):
+- Local execution routing
+- Remote execution routing
+- Argument serialization
+- Result deserialization
+- Error propagation
+
+#### Integration Tests
+
+**Cross-Endpoint Tests** (`tests/integration/test_cross_endpoint_routing.py`):
+- End-to-end local execution
+- End-to-end remote execution
+- Function call across endpoints
+- Error handling in real scenarios
+- Directory caching behavior
+- Serialization of complex objects
+
+#### Test Patterns
+
+```python
+# Unit test: Mock dependencies
+@pytest.mark.asyncio
+async def test_local_execution():
+ mock_registry = AsyncMock()
+ mock_registry.get_resource_for_function.return_value = None
+
+ wrapper = ProductionWrapper(mock_registry)
+ result = await wrapper.wrap_function_execution(
+ original_stub_func=mock_original,
+ func=mock_func,
+ ...
+ )
+
+ assert result == expected_result
+
+# Integration test: Real components
+@pytest.mark.asyncio
+async def test_cross_endpoint_execution():
+ registry = ServiceRegistry(manifest_path=manifest)
+ wrapper = ProductionWrapper(registry)
+
+ # Actual function call
+ result = await wrapper.wrap_function_execution(
+ original_stub_func=real_stub,
+ func=test_function,
+ ...
+ )
+
+ assert result == expected_result
+```
+
+### Debugging
+
+#### Logging
+
+Enable debug logging to trace routing decisions:
+
+```python
+import logging
+logging.basicConfig(level=logging.DEBUG)
+
+# ProductionWrapper logs
+# ServiceRegistry logs
+# ManifestClient logs
+```
+
+#### Common Debug Scenarios
+
+**Function not routing to remote**:
+```python
+# Check manifest
+print(registry._manifest)
+
+# Check directory
+print(registry._directory)
+
+# Check resource lookup
+resource = registry.get_resource_for_function("function_name")
+print(f"Resource: {resource}")
+```
+
+**Serialization failures**:
+```python
+# Test pickling directly
+import cloudpickle
+try:
+ cloudpickle.dumps(args)
+except Exception as e:
+ print(f"Not serializable: {e}")
+```
+
+**Directory unavailable**:
+```python
+# Check environment variables
+import os
+print(f"FLASH_MOTHERSHIP_URL: {os.getenv('FLASH_MOTHERSHIP_URL')}")
+print(f"RUNPOD_ENDPOINT_ID: {os.getenv('RUNPOD_ENDPOINT_ID')}")
+
+# Check directory client directly
+client = ManifestClient(mothership_url=...)
+directory = await client.get_directory()
+```
+
+## Key Implementation Highlights
+
+### Design Focus
+
+1. **Transparent Routing**: Functions route automatically without code changes
+2. **Graceful Degradation**: Defaults to local execution if directory unavailable
+3. **Type Safety**: Full type hints throughout for IDE support and static analysis
+4. **Thread-Safe Async**: Proper `asyncio.Lock()` usage for concurrent operations
+5. **Clear Error Hierarchy**: Custom exceptions provide actionable error context
+
+## Conclusion
+
+Cross-endpoint routing provides:
+
+- **Transparency**: Functions route automatically without manual HTTP calls
+- **Flexibility**: Manifest-based routing enables environment-specific configurations
+- **Resilience**: Graceful fallback to local execution if directory unavailable
+- **Simplicity**: No changes to function code or signatures
+- **Debuggability**: Clear error messages and logging for troubleshooting
+
+The architecture prioritizes clarity and maintainability while enabling distributed serverless applications.
diff --git a/src/tetra_rp/cli/commands/build_utils/handler_generator.py b/src/tetra_rp/cli/commands/build_utils/handler_generator.py
index c1e5d616..ef6ec318 100644
--- a/src/tetra_rp/cli/commands/build_utils/handler_generator.py
+++ b/src/tetra_rp/cli/commands/build_utils/handler_generator.py
@@ -4,7 +4,9 @@
import importlib.util
import logging
from pathlib import Path
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Union
+
+from tetra_rp.runtime.models import Manifest
logger = logging.getLogger(__name__)
@@ -38,7 +40,7 @@
class HandlerGenerator:
"""Generates handler_.py files for each resource config."""
- def __init__(self, manifest: Dict[str, Any], build_dir: Path):
+ def __init__(self, manifest: Union[Dict[str, Any], Manifest], build_dir: Path):
self.manifest = manifest
self.build_dir = build_dir
@@ -46,9 +48,21 @@ def generate_handlers(self) -> List[Path]:
"""Generate all handler files for queue-based (non-LB) resources."""
handler_paths = []
- for resource_name, resource_data in self.manifest.get("resources", {}).items():
+ # Handle both dict and Manifest types
+ resources = (
+ self.manifest.resources
+ if isinstance(self.manifest, Manifest)
+ else self.manifest.get("resources", {})
+ )
+
+ for resource_name, resource_data in resources.items():
# Skip load-balanced resources (handled by LBHandlerGenerator)
- if resource_data.get("resource_type") == "LoadBalancerSlsResource":
+ resource_type = (
+ resource_data.resource_type
+ if hasattr(resource_data, "resource_type")
+ else resource_data.get("resource_type")
+ )
+ if resource_type == "LoadBalancerSlsResource":
continue
handler_path = self._generate_handler(resource_name, resource_data)
@@ -56,21 +70,30 @@ def generate_handlers(self) -> List[Path]:
return handler_paths
- def _generate_handler(
- self, resource_name: str, resource_data: Dict[str, Any]
- ) -> Path:
+ def _generate_handler(self, resource_name: str, resource_data: Any) -> Path:
"""Generate a single handler file."""
handler_filename = f"handler_{resource_name}.py"
handler_path = self.build_dir / handler_filename
# Get timestamp from manifest
- timestamp = self.manifest.get("generated_at", "")
+ timestamp = (
+ self.manifest.generated_at
+ if isinstance(self.manifest, Manifest)
+ else self.manifest.get("generated_at", "")
+ )
+
+ # Get functions from resource (handle both dict and ResourceConfig)
+ functions = (
+ resource_data.functions
+ if hasattr(resource_data, "functions")
+ else resource_data.get("functions", [])
+ )
# Generate imports section
- imports = self._generate_imports(resource_data.get("functions", []))
+ imports = self._generate_imports(functions)
# Generate function registry
- registry = self._generate_registry(resource_data.get("functions", []))
+ registry = self._generate_registry(functions)
# Format template
handler_code = HANDLER_TEMPLATE.format(
@@ -87,7 +110,7 @@ def _generate_handler(
return handler_path
- def _generate_imports(self, functions: List[Dict[str, Any]]) -> str:
+ def _generate_imports(self, functions: List[Any]) -> str:
"""Generate import statements for functions using dynamic imports.
Uses importlib.import_module() to handle module names with invalid
@@ -98,8 +121,9 @@ def _generate_imports(self, functions: List[Dict[str, Any]]) -> str:
imports = []
for func in functions:
- module = func.get("module")
- name = func.get("name")
+ # Handle both dict and FunctionMetadata
+ module = func.module if hasattr(func, "module") else func.get("module")
+ name = func.name if hasattr(func, "name") else func.get("name")
if module and name:
# Use dynamic import to handle invalid identifiers
@@ -107,7 +131,7 @@ def _generate_imports(self, functions: List[Dict[str, Any]]) -> str:
return "\n".join(imports) if imports else "# No functions to import"
- def _generate_registry(self, functions: List[Dict[str, Any]]) -> str:
+ def _generate_registry(self, functions: List[Any]) -> str:
"""Generate function registry dictionary."""
if not functions:
return " # No functions registered"
@@ -115,7 +139,8 @@ def _generate_registry(self, functions: List[Dict[str, Any]]) -> str:
registry_lines = []
for func in functions:
- name = func.get("name")
+ # Handle both dict and FunctionMetadata
+ name = func.name if hasattr(func, "name") else func.get("name")
registry_lines.append(f' "{name}": {name},')
return "\n".join(registry_lines)
diff --git a/src/tetra_rp/cli/commands/build_utils/lb_handler_generator.py b/src/tetra_rp/cli/commands/build_utils/lb_handler_generator.py
index c7bf0ba7..f7c679c3 100644
--- a/src/tetra_rp/cli/commands/build_utils/lb_handler_generator.py
+++ b/src/tetra_rp/cli/commands/build_utils/lb_handler_generator.py
@@ -3,7 +3,9 @@
import importlib.util
import logging
from pathlib import Path
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Union
+
+from tetra_rp.runtime.models import Manifest
logger = logging.getLogger(__name__)
@@ -58,7 +60,7 @@ def ping():
class LBHandlerGenerator:
"""Generates FastAPI handlers for LoadBalancerSlsResource endpoints."""
- def __init__(self, manifest: Dict[str, Any], build_dir: Path):
+ def __init__(self, manifest: Union[Dict[str, Any], Manifest], build_dir: Path):
self.manifest = manifest
self.build_dir = build_dir
@@ -66,9 +68,20 @@ def generate_handlers(self) -> List[Path]:
"""Generate all LB handler files."""
handler_paths = []
- for resource_name, resource_data in self.manifest.get("resources", {}).items():
+ # Handle both dict and Manifest types
+ resources = (
+ self.manifest.resources
+ if isinstance(self.manifest, Manifest)
+ else self.manifest.get("resources", {})
+ )
+
+ for resource_name, resource_data in resources.items():
# Generate for both LiveLoadBalancer (local dev) and LoadBalancerSlsResource (deployed)
- resource_type = resource_data.get("resource_type")
+ resource_type = (
+ resource_data.resource_type
+ if hasattr(resource_data, "resource_type")
+ else resource_data.get("resource_type")
+ )
if resource_type not in ["LoadBalancerSlsResource", "LiveLoadBalancer"]:
continue
@@ -77,26 +90,39 @@ def generate_handlers(self) -> List[Path]:
return handler_paths
- def _generate_handler(
- self, resource_name: str, resource_data: Dict[str, Any]
- ) -> Path:
+ def _generate_handler(self, resource_name: str, resource_data: Any) -> Path:
"""Generate a single FastAPI handler file."""
handler_filename = f"handler_{resource_name}.py"
handler_path = self.build_dir / handler_filename
# Get timestamp from manifest
- timestamp = self.manifest.get("generated_at", "")
+ timestamp = (
+ self.manifest.generated_at
+ if isinstance(self.manifest, Manifest)
+ else self.manifest.get("generated_at", "")
+ )
# Determine if /execute endpoint should be included
# LiveLoadBalancer (local dev) includes /execute, deployed LoadBalancerSlsResource does not
- resource_type = resource_data.get("resource_type", "LoadBalancerSlsResource")
+ resource_type = (
+ resource_data.resource_type
+ if hasattr(resource_data, "resource_type")
+ else resource_data.get("resource_type", "LoadBalancerSlsResource")
+ )
include_execute = resource_type == "LiveLoadBalancer"
+ # Get functions from resource (handle both dict and ResourceConfig)
+ functions = (
+ resource_data.functions
+ if hasattr(resource_data, "functions")
+ else resource_data.get("functions", [])
+ )
+
# Generate imports section
- imports = self._generate_imports(resource_data.get("functions", []))
+ imports = self._generate_imports(functions)
# Generate route registry
- registry = self._generate_route_registry(resource_data.get("functions", []))
+ registry = self._generate_route_registry(functions)
# Format template
handler_code = LB_HANDLER_TEMPLATE.format(
@@ -114,11 +140,11 @@ def _generate_handler(
return handler_path
- def _generate_imports(self, functions: List[Dict[str, Any]]) -> str:
+ def _generate_imports(self, functions: List[Any]) -> str:
"""Generate import statements for functions.
Args:
- functions: List of function metadata dicts
+ functions: List of function metadata (dicts or FunctionMetadata objects)
Returns:
Import statements as string
@@ -126,15 +152,16 @@ def _generate_imports(self, functions: List[Dict[str, Any]]) -> str:
imports = []
for func in functions:
- module = func.get("module")
- name = func.get("name")
+ # Handle both dict and FunctionMetadata
+ module = func.module if hasattr(func, "module") else func.get("module")
+ name = func.name if hasattr(func, "name") else func.get("name")
if module and name:
imports.append(f"from {module} import {name}")
return "\n".join(imports) if imports else "# No functions to import"
- def _generate_route_registry(self, functions: List[Dict[str, Any]]) -> str:
+ def _generate_route_registry(self, functions: List[Any]) -> str:
"""Generate route registry for FastAPI app.
Creates mapping of (method, path) tuples to function names.
@@ -151,9 +178,16 @@ def _generate_route_registry(self, functions: List[Dict[str, Any]]) -> str:
registry_lines = []
for func in functions:
- name = func.get("name")
- method = func.get("http_method")
- path = func.get("http_path")
+ # Handle both dict and FunctionMetadata
+ name = func.name if hasattr(func, "name") else func.get("name")
+ method = (
+ func.http_method
+ if hasattr(func, "http_method")
+ else func.get("http_method")
+ )
+ path = (
+ func.http_path if hasattr(func, "http_path") else func.get("http_path")
+ )
if name and method and path:
# Create tuple key: ("GET", "/api/process")
diff --git a/src/tetra_rp/execute_class.py b/src/tetra_rp/execute_class.py
index 70db2fa7..0e301d5d 100644
--- a/src/tetra_rp/execute_class.py
+++ b/src/tetra_rp/execute_class.py
@@ -6,7 +6,6 @@
prevent memory leaks through LRU eviction.
"""
-import base64
import hashlib
import inspect
import logging
@@ -20,6 +19,8 @@
from .core.utils.constants import HASH_TRUNCATE_LENGTH, UUID_FALLBACK_LENGTH
from .core.utils.lru_cache import LRUCache
from .protos.remote_execution import FunctionRequest
+from .runtime.exceptions import SerializationError
+from .runtime.serialization import serialize_args, serialize_kwargs
from .stubs import stub_resource
log = logging.getLogger(__name__)
@@ -30,14 +31,7 @@
def serialize_constructor_args(args, kwargs):
"""Serialize constructor arguments for caching."""
- serialized_args = [
- base64.b64encode(cloudpickle.dumps(arg)).decode("utf-8") for arg in args
- ]
- serialized_kwargs = {
- k: base64.b64encode(cloudpickle.dumps(v)).decode("utf-8")
- for k, v in kwargs.items()
- }
- return serialized_args, serialized_kwargs
+ return serialize_args(args), serialize_kwargs(kwargs)
def get_or_cache_class_data(
@@ -65,7 +59,7 @@ def get_or_cache_class_data(
log.debug(f"Cached class data for {cls.__name__} with key: {cache_key}")
- except (TypeError, AttributeError, OSError) as e:
+ except (TypeError, AttributeError, OSError, SerializationError) as e:
log.warning(
f"Could not serialize constructor arguments for {cls.__name__}: {e}"
)
@@ -267,14 +261,8 @@ async def method_proxy(*args, **kwargs):
cached_data = _SERIALIZED_CLASS_CACHE.get(self._cache_key)
# Serialize method arguments (these change per call, so no caching)
- method_args = [
- base64.b64encode(cloudpickle.dumps(arg)).decode("utf-8")
- for arg in args
- ]
- method_kwargs = {
- k: base64.b64encode(cloudpickle.dumps(v)).decode("utf-8")
- for k, v in kwargs.items()
- }
+ method_args = serialize_args(args)
+ method_kwargs = serialize_kwargs(kwargs)
# Handle constructor args - use cached if available, else serialize fresh
if cached_data["constructor_args"] is not None:
@@ -284,14 +272,8 @@ async def method_proxy(*args, **kwargs):
else:
# Constructor args couldn't be cached due to serialization issues
# Serialize them fresh for each method call (fallback behavior)
- constructor_args = [
- base64.b64encode(cloudpickle.dumps(arg)).decode("utf-8")
- for arg in self._constructor_args
- ]
- constructor_kwargs = {
- k: base64.b64encode(cloudpickle.dumps(v)).decode("utf-8")
- for k, v in self._constructor_kwargs.items()
- }
+ constructor_args = serialize_args(self._constructor_args)
+ constructor_kwargs = serialize_kwargs(self._constructor_kwargs)
request = FunctionRequest(
execution_type="class",
diff --git a/src/tetra_rp/runtime/config.py b/src/tetra_rp/runtime/config.py
new file mode 100644
index 00000000..c0efc11f
--- /dev/null
+++ b/src/tetra_rp/runtime/config.py
@@ -0,0 +1,12 @@
+"""Configuration constants for runtime module."""
+
+# HTTP client configuration
+DEFAULT_REQUEST_TIMEOUT = 10 # seconds
+DEFAULT_MAX_RETRIES = 3
+DEFAULT_BACKOFF_BASE = 2
+
+# Directory cache configuration
+DEFAULT_CACHE_TTL = 300 # seconds
+
+# Serialization limits
+MAX_PAYLOAD_SIZE = 10 * 1024 * 1024 # 10MB
diff --git a/src/tetra_rp/runtime/exceptions.py b/src/tetra_rp/runtime/exceptions.py
new file mode 100644
index 00000000..fec800fd
--- /dev/null
+++ b/src/tetra_rp/runtime/exceptions.py
@@ -0,0 +1,31 @@
+"""Custom exceptions for cross-endpoint runtime."""
+
+
+class FlashRuntimeError(Exception):
+ """Base exception for runtime errors in cross-endpoint execution."""
+
+ pass
+
+
+class RemoteExecutionError(FlashRuntimeError):
+ """Raised when remote function execution fails."""
+
+ pass
+
+
+class SerializationError(FlashRuntimeError):
+ """Raised when serialization or deserialization of arguments fails."""
+
+ pass
+
+
+class ManifestError(FlashRuntimeError):
+ """Raised when manifest is invalid, missing, or has unexpected structure."""
+
+ pass
+
+
+class ManifestServiceUnavailableError(FlashRuntimeError):
+ """Raised when manifest directory service is unavailable."""
+
+ pass
diff --git a/src/tetra_rp/runtime/generic_handler.py b/src/tetra_rp/runtime/generic_handler.py
index c3d3bf8d..0bf9aa39 100644
--- a/src/tetra_rp/runtime/generic_handler.py
+++ b/src/tetra_rp/runtime/generic_handler.py
@@ -1,13 +1,12 @@
"""Generic RunPod serverless handler factory for Flash."""
-import base64
import json
import logging
import traceback
from pathlib import Path
from typing import Any, Callable, Dict
-import cloudpickle
+from .serialization import deserialize_args, deserialize_kwargs, serialize_arg
logger = logging.getLogger(__name__)
@@ -64,13 +63,8 @@ def deserialize_arguments(job_input: Dict[str, Any]) -> tuple[list, dict]:
Returns:
Tuple of (args list, kwargs dict) deserialized from cloudpickle
"""
- args = [
- cloudpickle.loads(base64.b64decode(arg)) for arg in job_input.get("args", [])
- ]
- kwargs = {
- k: cloudpickle.loads(base64.b64decode(v))
- for k, v in job_input.get("kwargs", {}).items()
- }
+ args = deserialize_args(job_input.get("args", []))
+ kwargs = deserialize_kwargs(job_input.get("kwargs", {}))
return args, kwargs
@@ -83,7 +77,7 @@ def serialize_result(result: Any) -> str:
Returns:
Base64-encoded cloudpickle of result
"""
- return base64.b64encode(cloudpickle.dumps(result)).decode("utf-8")
+ return serialize_arg(result)
def execute_function(
diff --git a/src/tetra_rp/runtime/lb_handler.py b/src/tetra_rp/runtime/lb_handler.py
index 6f7c198c..82fec707 100644
--- a/src/tetra_rp/runtime/lb_handler.py
+++ b/src/tetra_rp/runtime/lb_handler.py
@@ -15,14 +15,18 @@
Users should NOT expose the /execute endpoint to untrusted clients.
"""
-import base64
import inspect
import logging
from typing import Any, Callable, Dict
-import cloudpickle
from fastapi import FastAPI, Request
+from .serialization import (
+ deserialize_args,
+ deserialize_kwargs,
+ serialize_arg,
+)
+
logger = logging.getLogger(__name__)
@@ -94,29 +98,15 @@ async def execute_remote_function(request: Request) -> Dict[str, Any]:
}
# Deserialize arguments
- args = []
- for arg_b64 in body.get("args", []):
- try:
- arg = cloudpickle.loads(base64.b64decode(arg_b64))
- args.append(arg)
- except Exception as e:
- logger.error(f"Failed to deserialize argument: {e}")
- return {
- "success": False,
- "error": f"Failed to deserialize argument: {e}",
- }
-
- kwargs = {}
- for key, val_b64 in body.get("kwargs", {}).items():
- try:
- val = cloudpickle.loads(base64.b64decode(val_b64))
- kwargs[key] = val
- except Exception as e:
- logger.error(f"Failed to deserialize kwarg '{key}': {e}")
- return {
- "success": False,
- "error": f"Failed to deserialize kwarg '{key}': {e}",
- }
+ try:
+ args = deserialize_args(body.get("args", []))
+ kwargs = deserialize_kwargs(body.get("kwargs", {}))
+ except Exception as e:
+ logger.error(f"Failed to deserialize arguments: {e}")
+ return {
+ "success": False,
+ "error": f"Failed to deserialize arguments: {e}",
+ }
# Execute function in isolated namespace
namespace: Dict[str, Any] = {}
@@ -160,9 +150,7 @@ async def execute_remote_function(request: Request) -> Dict[str, Any]:
# Serialize result
try:
- result_b64 = base64.b64encode(cloudpickle.dumps(result)).decode(
- "utf-8"
- )
+ result_b64 = serialize_arg(result)
return {"success": True, "result": result_b64}
except Exception as e:
logger.error(f"Failed to serialize result: {e}")
diff --git a/src/tetra_rp/runtime/manifest_client.py b/src/tetra_rp/runtime/manifest_client.py
new file mode 100644
index 00000000..bfe69ca8
--- /dev/null
+++ b/src/tetra_rp/runtime/manifest_client.py
@@ -0,0 +1,137 @@
+"""HTTP client for mothership manifest directory API."""
+
+import asyncio
+import logging
+import os
+from typing import Dict, Optional
+
+try:
+ import httpx
+except ImportError:
+ httpx = None
+
+from .config import DEFAULT_MAX_RETRIES, DEFAULT_REQUEST_TIMEOUT
+from .exceptions import ManifestServiceUnavailableError
+
+logger = logging.getLogger(__name__)
+
+
+class ManifestClient:
+ """HTTP client for querying mothership manifest directory service.
+
+ Fetches the endpoint registry that maps resource_config names to their
+ deployment URLs. This is the "manifest directory service" - an endpoint
+ registry showing where resources are deployed.
+
+ The directory maps resource_config names to their endpoint URLs.
+ Example: {"gpu_config": "https://api.runpod.io/v2/abc123"}
+ """
+
+ def __init__(
+ self,
+ mothership_url: Optional[str] = None,
+ timeout: int = DEFAULT_REQUEST_TIMEOUT,
+ max_retries: int = DEFAULT_MAX_RETRIES,
+ ):
+ """Initialize manifest client.
+
+ Args:
+ mothership_url: Base URL of mothership endpoint. Defaults to
+ FLASH_MOTHERSHIP_URL environment variable.
+ timeout: Request timeout in seconds (default: 10).
+ max_retries: Maximum retry attempts (default: 3).
+
+ Raises:
+ ValueError: If mothership_url not provided and env var not set.
+ """
+ self.mothership_url = mothership_url or os.getenv("FLASH_MOTHERSHIP_URL")
+ if not self.mothership_url:
+ raise ValueError(
+ "mothership_url required: pass mothership_url or set "
+ "FLASH_MOTHERSHIP_URL environment variable"
+ )
+
+ self.timeout = timeout
+ self.max_retries = max_retries
+ self._client: Optional["httpx.AsyncClient"] = None  # quoted: httpx may be None here
+
+ async def get_directory(self) -> Dict[str, str]:
+ """Fetch endpoint directory from mothership.
+
+ Returns:
+ Dictionary mapping resource_config_name → endpoint_url.
+ Example: {"gpu_config": "https://api.runpod.io/v2/abc123"}
+
+ Raises:
+ ManifestServiceUnavailableError: If manifest directory service unavailable after retries.
+ """
+ if httpx is None:
+ raise ImportError(
+ "httpx required for ManifestClient. Install with: pip install httpx"
+ )
+
+ last_exception: Optional[Exception] = None
+
+ for attempt in range(self.max_retries):
+ try:
+ client = await self._get_client()
+ response = await client.get(
+ f"{self.mothership_url}/directory",
+ timeout=self.timeout,
+ )
+
+ if response.status_code >= 400:
+ raise ManifestServiceUnavailableError(
+ f"Directory API returned {response.status_code}: "
+ f"{response.text[:200]}"
+ )
+
+ data = response.json()
+ if "directory" not in data:
+ raise ManifestServiceUnavailableError(
+ "Invalid directory response: missing 'directory' key"
+ )
+
+ directory = data["directory"]
+ logger.debug(f"Directory loaded: {len(directory)} endpoints")
+ return directory
+
+ except Exception as e:  # broad retry: covers timeouts and ManifestServiceUnavailableError
+ last_exception = e
+ if attempt < self.max_retries - 1:
+ backoff = 2**attempt
+ logger.warning(
+ f"Manifest service request failed (attempt {attempt + 1}): {e}, "
+ f"retrying in {backoff}s..."
+ )
+ await asyncio.sleep(backoff)
+ continue
+
+ raise ManifestServiceUnavailableError(
+ f"Failed to fetch manifest directory after {self.max_retries} attempts: {last_exception}"
+ )
+
+ async def _get_client(self) -> "httpx.AsyncClient":  # quoted: annotation evaluated at def time
+ """Get or create HTTP client with proper configuration."""
+ if self._client is None or self._client.is_closed:
+ timeout = httpx.Timeout(self.timeout)
+ self._client = httpx.AsyncClient(timeout=timeout)
+
+ return self._client
+
+ async def close(self) -> None:
+ """Close HTTP session."""
+ if self._client and not self._client.is_closed:
+ await self._client.aclose()
+
+ async def __aenter__(self):
+ """Async context manager entry."""
+ return self
+
+ async def __aexit__(self, exc_type, exc_val, exc_tb):
+ """Async context manager exit."""
+ await self.close()
diff --git a/src/tetra_rp/runtime/models.py b/src/tetra_rp/runtime/models.py
new file mode 100644
index 00000000..583bf345
--- /dev/null
+++ b/src/tetra_rp/runtime/models.py
@@ -0,0 +1,73 @@
+"""Type-safe models for manifest handling."""
+
+from dataclasses import asdict, dataclass, field
+from typing import Any, Dict, List, Optional
+
+
+@dataclass
+class FunctionMetadata:
+ """Function metadata in manifest."""
+
+ name: str
+ module: str
+ is_async: bool
+ is_class: bool = False
+ http_method: Optional[str] = None
+ http_path: Optional[str] = None
+
+
+@dataclass
+class ResourceConfig:
+ """Resource configuration in manifest."""
+
+ resource_type: str
+ handler_file: str
+ functions: List[FunctionMetadata] = field(default_factory=list)
+
+ @classmethod
+ def from_dict(cls, data: Dict[str, Any]) -> "ResourceConfig":
+ """Load ResourceConfig from dict."""
+ functions = [
+ FunctionMetadata(**func_data) for func_data in data.get("functions", [])
+ ]
+ return cls(
+ resource_type=data["resource_type"],
+ handler_file=data["handler_file"],
+ functions=functions,
+ )
+
+
+@dataclass
+class Manifest:
+ """Type-safe manifest structure."""
+
+ version: str
+ generated_at: str
+ project_name: str
+ function_registry: Dict[str, str]
+ resources: Dict[str, ResourceConfig]
+ routes: Optional[Dict[str, Dict[str, str]]] = None
+
+ @classmethod
+ def from_dict(cls, data: Dict[str, Any]) -> "Manifest":
+ """Load Manifest from JSON dict."""
+ resources = {}
+ for resource_name, resource_data in data.get("resources", {}).items():
+ resources[resource_name] = ResourceConfig.from_dict(resource_data)
+
+ return cls(
+ version=data.get("version", "1.0"),
+ generated_at=data.get("generated_at", ""),
+ project_name=data.get("project_name", ""),
+ function_registry=data.get("function_registry", {}),
+ resources=resources,
+ routes=data.get("routes"),
+ )
+
+ def to_dict(self) -> Dict[str, Any]:
+ """Convert to JSON-serializable dict."""
+ result = asdict(self)
+ # Remove None routes to keep JSON clean
+ if result.get("routes") is None:
+ result.pop("routes", None)
+ return result
diff --git a/src/tetra_rp/runtime/production_wrapper.py b/src/tetra_rp/runtime/production_wrapper.py
new file mode 100644
index 00000000..65ce815d
--- /dev/null
+++ b/src/tetra_rp/runtime/production_wrapper.py
@@ -0,0 +1,265 @@
+"""Production wrapper for cross-endpoint function routing."""
+
+import logging
+from typing import Any, Callable, Dict, Optional
+
+from tetra_rp.core.resources.serverless import ServerlessResource
+
+from .exceptions import RemoteExecutionError
+from .serialization import serialize_args, serialize_kwargs
+from .service_registry import ServiceRegistry
+
+logger = logging.getLogger(__name__)
+
+
+class ProductionWrapper:
+ """Wrapper that routes function execution between endpoints.
+
+ Intercepts stub execution and determines if the call is local (execute
+ directly) or remote (call via HTTP to another endpoint).
+ """
+
+ def __init__(self, service_registry: ServiceRegistry):
+ """Initialize production wrapper.
+
+ Args:
+ service_registry: Service registry for routing decisions.
+ """
+ self.service_registry = service_registry
+ self._directory_loaded = False
+
+ async def wrap_function_execution(
+ self,
+ original_stub_func: Callable,
+ func: Callable,
+ dependencies: Optional[list],
+ system_dependencies: Optional[list],
+ accelerate_downloads: bool,
+ *args: Any,
+ **kwargs: Any,
+ ) -> Any:
+ """Route function execution to local or remote endpoint.
+
+ Args:
+ original_stub_func: The original stubbed_resource function.
+ func: The decorated function being called.
+ dependencies: Pip dependencies (for local execution).
+ system_dependencies: System dependencies (for local execution).
+ accelerate_downloads: Download acceleration flag (for local).
+ *args: Function positional arguments.
+ **kwargs: Function keyword arguments.
+
+ Returns:
+ Function execution result.
+
+ Raises:
+ Exception: If execution fails.
+ """
+ function_name = func.__name__
+
+ # Ensure directory is loaded
+ await self.service_registry._ensure_directory_loaded()
+
+ # Determine routing
+ try:
+ resource = self.service_registry.get_resource_for_function(function_name)
+ except ValueError as e:
+ # Function not in manifest, execute locally
+ logger.debug(
+ f"Function {function_name} not in manifest: {e}, executing locally"
+ )
+ return await original_stub_func(
+ func,
+ dependencies,
+ system_dependencies,
+ accelerate_downloads,
+ *args,
+ **kwargs,
+ )
+
+ # Local execution
+ if resource is None:
+ logger.debug(f"Executing local function: {function_name}")
+ return await original_stub_func(
+ func,
+ dependencies,
+ system_dependencies,
+ accelerate_downloads,
+ *args,
+ **kwargs,
+ )
+
+ # Remote execution
+ logger.debug(f"Routing function {function_name} to remote endpoint")
+ return await self._execute_remote(
+ resource,
+ function_name,
+ args,
+ kwargs,
+ execution_type="function",
+ )
+
+ async def wrap_class_method_execution(
+ self,
+ original_method_func: Callable,
+ request: Any,
+ ) -> Any:
+ """Route class method execution to local or remote endpoint.
+
+ Args:
+ original_method_func: The original execute_class_method function.
+ request: FunctionRequest containing class and method info.
+
+ Returns:
+ Method execution result.
+
+ Raises:
+ Exception: If execution fails.
+ """
+ # Ensure directory is loaded
+ await self.service_registry._ensure_directory_loaded()
+
+ class_name = getattr(request, "class_name", None)
+
+ if not class_name:
+ # No class name, execute locally
+ return await original_method_func(request)
+
+ # Determine routing
+ try:
+ resource = self.service_registry.get_resource_for_function(class_name)
+ except ValueError:
+ # Class not in manifest, execute locally
+ logger.debug(f"Class {class_name} not in manifest, executing locally")
+ return await original_method_func(request)
+
+ # Local execution
+ if resource is None:
+ logger.debug(f"Executing local class method: {class_name}")
+ return await original_method_func(request)
+
+ # Remote execution
+ logger.debug(f"Routing class {class_name} to remote endpoint")
+
+ # Convert FunctionRequest to dict payload
+ payload = self._build_class_payload(request)
+ return await self._execute_remote(
+ resource,
+ class_name,
+ (),
+ payload.get("input", {}),
+ execution_type="class",
+ )
+
+ async def _execute_remote(
+ self,
+ resource: ServerlessResource,
+ function_name: str,
+ args: tuple,
+ kwargs: dict,
+ execution_type: str = "function",
+ ) -> Any:
+ """Execute function on remote endpoint.
+
+ Args:
+ resource: ServerlessResource with endpoint ID set.
+ function_name: Name of function/class to execute.
+ args: Positional arguments.
+ kwargs: Keyword arguments.
+ execution_type: "function" or "class".
+
+ Returns:
+ Execution result.
+
+ Raises:
+ RemoteExecutionError: If remote execution fails.
+ """
+ # Serialize arguments
+ serialized_args = serialize_args(args)
+ serialized_kwargs = serialize_kwargs(kwargs)
+
+ # Build payload matching RunPod format
+ payload = {
+ "input": {
+ "function_name": function_name,
+ "execution_type": execution_type,
+ "args": serialized_args,
+ "kwargs": serialized_kwargs,
+ }
+ }
+
+ # Execute via ServerlessResource
+ result = await resource.run_sync(payload)
+
+ # Handle response
+ if result.error:
+ raise RemoteExecutionError(
+ f"Remote execution of {function_name} failed: {result.error}"
+ )
+
+ return result.output
+
+ def _build_class_payload(self, request: Any) -> Dict[str, Any]:
+ """Build payload from FunctionRequest for class execution.
+
+ Args:
+ request: FunctionRequest object.
+
+ Returns:
+ RunPod-format payload dict.
+ """
+ # Extract request data - handle both dict and object access patterns
+ if isinstance(request, dict):
+ data = request
+ else:
+ data = (
+ request.model_dump(exclude_none=True)
+ if hasattr(request, "model_dump")
+ else {}
+ )
+
+ # Extract class execution data
+ payload = {
+ "input": {
+ "function_name": data.get("class_name"),
+ "execution_type": "class",
+ "args": data.get("args", []),
+ "kwargs": data.get("kwargs", {}),
+ "method_name": data.get("method_name"),
+ }
+ }
+
+ return payload
+
+
+# Singleton instance management
+_wrapper_instance: Optional[ProductionWrapper] = None
+
+
+def create_production_wrapper(
+ service_registry: Optional[ServiceRegistry] = None,
+) -> ProductionWrapper:
+ """Create or get singleton ProductionWrapper instance.
+
+ Args:
+ service_registry: Service registry. Creates if not provided.
+
+ Returns:
+ ProductionWrapper instance.
+ """
+ global _wrapper_instance
+
+ if _wrapper_instance is None:
+ # Create components if not provided
+ if service_registry is None:
+ service_registry = ServiceRegistry()
+
+ _wrapper_instance = ProductionWrapper(service_registry)
+
+ return _wrapper_instance
+
+
+def reset_wrapper() -> None:
+ """Reset singleton wrapper (mainly for testing)."""
+ global _wrapper_instance
+ _wrapper_instance = None
diff --git a/src/tetra_rp/runtime/serialization.py b/src/tetra_rp/runtime/serialization.py
new file mode 100644
index 00000000..c063feb3
--- /dev/null
+++ b/src/tetra_rp/runtime/serialization.py
@@ -0,0 +1,124 @@
+"""Shared serialization utilities for cloudpickle + base64 encoding."""
+
+import base64
+from typing import Any, Dict, List
+
+import cloudpickle
+
+from .exceptions import SerializationError
+
+
+def serialize_arg(arg: Any) -> str:
+ """Serialize single argument with cloudpickle + base64.
+
+ Args:
+ arg: Argument to serialize.
+
+ Returns:
+ Base64-encoded cloudpickle serialized string.
+
+ Raises:
+ SerializationError: If serialization fails.
+ """
+ try:
+ return base64.b64encode(cloudpickle.dumps(arg)).decode("utf-8")
+ except Exception as e:
+ raise SerializationError(f"Failed to serialize argument: {e}") from e
+
+
+def serialize_args(args: tuple) -> List[str]:
+ """Serialize positional arguments.
+
+ Args:
+ args: Tuple of arguments to serialize.
+
+ Returns:
+ List of base64-encoded serialized arguments.
+
+ Raises:
+ SerializationError: If serialization fails.
+ """
+ try:
+ return [serialize_arg(arg) for arg in args]
+ except SerializationError:
+ raise
+ except Exception as e:
+ raise SerializationError(f"Failed to serialize args: {e}") from e
+
+
+def serialize_kwargs(kwargs: dict) -> Dict[str, str]:
+ """Serialize keyword arguments.
+
+ Args:
+ kwargs: Dictionary of keyword arguments.
+
+ Returns:
+ Dictionary with base64-encoded serialized values.
+
+ Raises:
+ SerializationError: If serialization fails.
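+
+    Example (round trip):
+        >>> deserialize_kwargs(serialize_kwargs({"n": 2}))
+        {'n': 2}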
+ """
+ try:
+ return {k: serialize_arg(v) for k, v in kwargs.items()}
+ except SerializationError:
+ raise
+ except Exception as e:
+ raise SerializationError(f"Failed to serialize kwargs: {e}") from e
+
+
+def deserialize_arg(arg_b64: str) -> Any:
+ """Deserialize single base64-encoded cloudpickle argument.
+
+ Args:
+ arg_b64: Base64-encoded serialized argument.
+
+ Returns:
+ Deserialized argument.
+
+ Raises:
+ SerializationError: If deserialization fails.
+ """
+ try:
+ return cloudpickle.loads(base64.b64decode(arg_b64))
+ except Exception as e:
+ raise SerializationError(f"Failed to deserialize argument: {e}") from e
+
+
+def deserialize_args(args_b64: List[str]) -> List[Any]:
+ """Deserialize list of base64-encoded arguments.
+
+ Args:
+ args_b64: List of base64-encoded serialized arguments.
+
+ Returns:
+ List of deserialized arguments.
+
+ Raises:
+ SerializationError: If deserialization fails.
+ """
+ try:
+ return [deserialize_arg(arg) for arg in args_b64]
+ except SerializationError:
+ raise
+ except Exception as e:
+ raise SerializationError(f"Failed to deserialize args: {e}") from e
+
+
+def deserialize_kwargs(kwargs_b64: Dict[str, str]) -> Dict[str, Any]:
+ """Deserialize dict of base64-encoded keyword arguments.
+
+ Args:
+ kwargs_b64: Dictionary with base64-encoded serialized values.
+
+ Returns:
+ Dictionary with deserialized values.
+
+ Raises:
+ SerializationError: If deserialization fails.
+ """
+ try:
+ return {k: deserialize_arg(v) for k, v in kwargs_b64.items()}
+ except SerializationError:
+ raise
+ except Exception as e:
+ raise SerializationError(f"Failed to deserialize kwargs: {e}") from e
diff --git a/src/tetra_rp/runtime/service_registry.py b/src/tetra_rp/runtime/service_registry.py
new file mode 100644
index 00000000..ddcbcd84
--- /dev/null
+++ b/src/tetra_rp/runtime/service_registry.py
@@ -0,0 +1,301 @@
+"""Runtime service registry for cross-endpoint function routing."""
+
+import asyncio
+import json
+import logging
+import os
+import time
+from pathlib import Path
+from typing import Dict, Optional
+from urllib.parse import urlparse
+
+from tetra_rp.core.resources.serverless import ServerlessResource
+
+from .config import DEFAULT_CACHE_TTL
+from .manifest_client import ManifestClient, ManifestServiceUnavailableError
+from .models import Manifest
+
+logger = logging.getLogger(__name__)
+
+
+class ServiceRegistry:
+ """Service discovery and routing for cross-endpoint function calls.
+
+ Loads manifest to map functions to resource configs, queries mothership
+ directory for endpoint URLs, and determines if function calls are local
+ or remote.
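+
+    Example (sketch; manifest path and function name are illustrative):
+        registry = ServiceRegistry(manifest_path=Path("flash_manifest.json"))
+        if registry.is_local_function("gpu_task"):
+            ...  # execute in-process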
+ """
+
+ def __init__(
+ self,
+ manifest_path: Optional[Path] = None,
+ directory_client: Optional[ManifestClient] = None,
+ cache_ttl: int = DEFAULT_CACHE_TTL,
+ ):
+ """Initialize service registry.
+
+ Args:
+ manifest_path: Path to flash_manifest.json. Defaults to
+ FLASH_MANIFEST_PATH env var or auto-detection.
+            directory_client: Manifest service client for the mothership API.
+                If None, one is created from the FLASH_MOTHERSHIP_URL env var.
+ cache_ttl: Directory cache lifetime in seconds (default: 300).
+
+        Note:
+            Initialization does not raise for a missing manifest or missing
+            env vars: it logs a warning and falls back to an empty manifest
+            (local-only execution) and a disabled directory client.
+ """
+ self.cache_ttl = cache_ttl
+ self._directory: Dict[str, str] = {}
+ self._directory_loaded_at = 0.0
+ self._manifest: Manifest = Manifest(
+ version="1.0",
+ generated_at="",
+ project_name="",
+ function_registry={},
+ resources={},
+ )
+ self._directory_lock = asyncio.Lock()
+
+ # Load manifest
+ self._load_manifest(manifest_path)
+
+ # Initialize manifest client
+ if directory_client is None:
+ mothership_url = os.getenv("FLASH_MOTHERSHIP_URL")
+ if mothership_url:
+ directory_client = ManifestClient(mothership_url=mothership_url)
+ else:
+ logger.warning("FLASH_MOTHERSHIP_URL not set, directory unavailable")
+ directory_client = None
+
+ self._directory_client = directory_client
+ self._current_endpoint = os.getenv("RUNPOD_ENDPOINT_ID")
+
+ def _load_manifest(self, manifest_path: Optional[Path]) -> None:
+ """Load flash_manifest.json.
+
+ Args:
+ manifest_path: Explicit path to manifest. Tries env var and
+ auto-detection if not provided.
+
+        Note:
+            Never raises. If no manifest is found, a warning is logged and
+            an empty manifest is used (routing disabled).
+ """
+ paths_to_try = []
+
+ # Explicit path
+ if manifest_path:
+ paths_to_try.append(manifest_path)
+
+ # Environment variable
+ env_path = os.getenv("FLASH_MANIFEST_PATH")
+ if env_path:
+ paths_to_try.append(Path(env_path))
+
+        # Auto-detection: three levels up from this file (alongside the
+        # package), or cwd
+ paths_to_try.extend(
+ [
+ Path(__file__).parent.parent.parent / "flash_manifest.json",
+ Path.cwd() / "flash_manifest.json",
+ ]
+ )
+
+ # Try each path
+ for path in paths_to_try:
+ if path and path.exists():
+ try:
+ with open(path) as f:
+ manifest_dict = json.load(f)
+ self._manifest = Manifest.from_dict(manifest_dict)
+ logger.debug(f"Manifest loaded from {path}")
+ return
+ except Exception as e:
+ logger.warning(f"Failed to load manifest from {path}: {e}")
+ continue
+
+ # No manifest found - log warning but don't fail
+ logger.warning(
+ "flash_manifest.json not found. Cross-endpoint routing disabled. "
+ "Manifest is required for routing functions between endpoints."
+ )
+ self._manifest = Manifest(
+ version="1.0",
+ generated_at="",
+ project_name="",
+ function_registry={},
+ resources={},
+ )
+
+ async def _ensure_directory_loaded(self) -> None:
+ """Load directory from mothership if cache expired or not loaded."""
+ async with self._directory_lock:
+ now = time.time()
+ cache_age = now - self._directory_loaded_at
+
+ if cache_age > self.cache_ttl:
+ if self._directory_client is None:
+ logger.debug("Directory client not available, skipping refresh")
+ return
+
+ try:
+ self._directory = await self._directory_client.get_directory()
+ self._directory_loaded_at = now
+ logger.debug(
+ f"Directory loaded: {len(self._directory)} endpoints, "
+ f"cache TTL {self.cache_ttl}s"
+ )
+ except ManifestServiceUnavailableError as e:
+ logger.warning(
+ f"Failed to load manifest directory: {e}. "
+ f"Cross-endpoint routing unavailable."
+ )
+ self._directory = {}
+
+ def get_endpoint_for_function(self, function_name: str) -> Optional[str]:
+ """Get endpoint URL for a function.
+
+ Determines if function is local (same endpoint) or remote (different
+ endpoint), returning None for local and URL for remote.
+
+ Args:
+ function_name: Name of the function to route.
+
+ Returns:
+            Endpoint URL if the function is remote and its URL is known;
+            None if it is local or the URL is not yet in the directory.
+
+ Raises:
+ ValueError: If function not in manifest.
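+
+        Example (illustrative function name):
+            url = registry.get_endpoint_for_function("preprocess")
+            if url is None:
+                ...  # run locally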
+ """
+ function_registry = self._manifest.function_registry
+
+ if function_name not in function_registry:
+ raise ValueError(
+ f"Function '{function_name}' not found in manifest. "
+ f"Available functions: {list(function_registry.keys())}"
+ )
+
+ resource_config_name = function_registry[function_name]
+
+ # Check if this is the current endpoint (local)
+ if resource_config_name == self._current_endpoint:
+ return None
+
+ # Check directory for remote endpoint URL
+ endpoint_url = self._directory.get(resource_config_name)
+ if not endpoint_url:
+ logger.debug(
+ f"Endpoint URL for '{resource_config_name}' not in directory. "
+ f"Directory has: {list(self._directory.keys())}"
+ )
+
+ return endpoint_url
+
+ def get_resource_for_function(
+ self, function_name: str
+ ) -> Optional[ServerlessResource]:
+ """Get ServerlessResource for a function.
+
+ Creates a ServerlessResource with the correct endpoint ID if the function
+ is remote, returns None if local.
+
+ Args:
+ function_name: Name of the function to route.
+
+ Returns:
+            ServerlessResource with its endpoint ID set if the function is
+            remote; None if it runs on the current endpoint.
+
+ Raises:
+ ValueError: If function not in manifest.
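+
+        Example (illustrative directory entry):
+            # "https://api.runpod.io/v2/abc123" -> resource.id == "abc123"
+            resource = registry.get_resource_for_function("preprocess")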
+ """
+ endpoint_url = self.get_endpoint_for_function(function_name)
+
+ if endpoint_url is None:
+ return None # Local function
+
+ # Extract endpoint ID from URL (format: https://api.runpod.io/v2/{endpoint_id})
+ try:
+ parsed = urlparse(endpoint_url)
+ # Get the last path component (the endpoint ID)
+ path_parts = parsed.path.rstrip("/").split("/")
+ endpoint_id = path_parts[-1] if path_parts else ""
+
+ if not endpoint_id:
+ raise ValueError(
+ f"Invalid endpoint URL format: {endpoint_url} - no endpoint ID found"
+ )
+ except Exception as e:
+ raise ValueError(
+ f"Failed to parse endpoint URL '{endpoint_url}': {e}"
+ ) from e
+
+ # Create and return ServerlessResource
+ resource = ServerlessResource(name=f"remote_{function_name}")
+ resource.id = endpoint_id
+
+ return resource
+
+ def is_local_function(self, function_name: str) -> bool:
+ """Check if function executes on current endpoint.
+
+ Args:
+ function_name: Name of the function.
+
+ Returns:
+            True if the function is local or not in the manifest,
+            False if it resolves to a remote endpoint.
+ """
+ try:
+ endpoint_url = self.get_endpoint_for_function(function_name)
+ return endpoint_url is None
+ except ValueError:
+            # Function not in manifest: assume local so the call proceeds
+            # and any failure surfaces at execution time
+ return True
+
+ def get_current_endpoint_id(self) -> Optional[str]:
+ """Get ID of current endpoint from environment.
+
+ Returns:
+            Endpoint ID read from RUNPOD_ENDPOINT_ID at initialization,
+            or None if it was not set.
+ """
+ return self._current_endpoint
+
+ def refresh_directory(self) -> None:
+ """Force refresh directory from mothership on next access."""
+ self._directory_loaded_at = 0
+
+ def get_manifest(self) -> Manifest:
+ """Get loaded manifest.
+
+ Returns:
+ Loaded Manifest object.
+ """
+ return self._manifest
+
+ def get_all_resources(self) -> Dict[str, Dict]:
+ """Get all resource configs from manifest.
+
+ Returns:
+ Dictionary of resource configs as dictionaries.
+ """
+ from dataclasses import asdict
+
+ return {
+ name: asdict(config) for name, config in self._manifest.resources.items()
+ }
+
+ def get_resource_functions(self, resource_name: str) -> list:
+ """Get list of functions for a resource.
+
+ Args:
+ resource_name: Name of the resource config.
+
+ Returns:
+ List of function metadata dictionaries.
+ """
+ resource = self._manifest.resources.get(resource_name)
+ if not resource:
+ return []
+ from dataclasses import asdict
+
+ return [asdict(func) for func in resource.functions]
diff --git a/src/tetra_rp/stubs/live_serverless.py b/src/tetra_rp/stubs/live_serverless.py
index 933c6c92..256e22d7 100644
--- a/src/tetra_rp/stubs/live_serverless.py
+++ b/src/tetra_rp/stubs/live_serverless.py
@@ -13,6 +13,7 @@
FunctionResponse,
RemoteExecutorStub,
)
+from ..runtime.serialization import serialize_args, serialize_kwargs
log = logging.getLogger(__name__)
@@ -100,14 +101,9 @@ def prepare_request(
# Serialize arguments using cloudpickle
if args:
- request["args"] = [
- base64.b64encode(cloudpickle.dumps(arg)).decode("utf-8") for arg in args
- ]
+ request["args"] = serialize_args(args)
if kwargs:
- request["kwargs"] = {
- k: base64.b64encode(cloudpickle.dumps(v)).decode("utf-8")
- for k, v in kwargs.items()
- }
+ request["kwargs"] = serialize_kwargs(kwargs)
return FunctionRequest(**request)
diff --git a/src/tetra_rp/stubs/load_balancer_sls.py b/src/tetra_rp/stubs/load_balancer_sls.py
index 1edd1a90..61edcb3f 100644
--- a/src/tetra_rp/stubs/load_balancer_sls.py
+++ b/src/tetra_rp/stubs/load_balancer_sls.py
@@ -4,15 +4,18 @@
via direct HTTP calls instead of queue-based job submission.
"""
-import base64
import inspect
import logging
from typing import Any, Callable, Dict, List, Optional
import httpx
-import cloudpickle
from tetra_rp.core.utils.http import get_authenticated_httpx_client
+from tetra_rp.runtime.serialization import (
+ deserialize_arg,
+ serialize_args,
+ serialize_kwargs,
+)
from .live_serverless import get_function_source
log = logging.getLogger(__name__)
@@ -188,18 +191,11 @@ def _prepare_request(
# Serialize arguments using cloudpickle + base64
if args:
- serialized_args = [
- base64.b64encode(cloudpickle.dumps(arg)).decode("utf-8") for arg in args
- ]
- request["args"] = serialized_args
+ request["args"] = serialize_args(args)
log.debug(f"Serialized {len(args)} positional args for {func.__name__}")
if kwargs:
- serialized_kwargs = {
- k: base64.b64encode(cloudpickle.dumps(v)).decode("utf-8")
- for k, v in kwargs.items()
- }
- request["kwargs"] = serialized_kwargs
+ request["kwargs"] = serialize_kwargs(kwargs)
log.debug(f"Serialized {len(kwargs)} keyword args for {func.__name__}")
return request
@@ -348,7 +344,7 @@ def _handle_response(self, response: Dict[str, Any]) -> Any:
raise ValueError("Response marked success but result is None")
try:
- result = cloudpickle.loads(base64.b64decode(result_b64))
+ result = deserialize_arg(result_b64)
log.debug(
f"Successfully deserialized response result (type={type(result).__name__})"
)
diff --git a/src/tetra_rp/stubs/registry.py b/src/tetra_rp/stubs/registry.py
index c6363726..674e0085 100644
--- a/src/tetra_rp/stubs/registry.py
+++ b/src/tetra_rp/stubs/registry.py
@@ -1,4 +1,5 @@
import logging
+import os
from functools import singledispatch
from ..core.resources import (
@@ -56,6 +57,46 @@ async def execute_class_method(request):
response = await stub.ExecuteFunction(request)
return stub.handle_response(response)
+ # Inject ProductionWrapper if in production mode
+ if os.getenv("RUNPOD_ENDPOINT_ID"):
+ try:
+ from ..runtime.production_wrapper import create_production_wrapper
+
+ wrapper = create_production_wrapper()
+ original_stubbed = stubbed_resource
+ original_class_method = execute_class_method
+
+ async def wrapped_stubbed(
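+            # The wrapped variants preserve the original call signatures so
+            # they can be swapped in transparently; the wrapper decides per
+            # call whether to execute locally or route to a remote endpoint.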
+ func,
+ dependencies,
+ system_dependencies,
+ accelerate_downloads,
+ *args,
+ **kwargs,
+ ):
+ return await wrapper.wrap_function_execution(
+ original_stubbed,
+ func,
+ dependencies,
+ system_dependencies,
+ accelerate_downloads,
+ *args,
+ **kwargs,
+ )
+
+ async def wrapped_class_method(request):
+ return await wrapper.wrap_class_method_execution(
+ original_class_method, request
+ )
+
+ stubbed_resource = wrapped_stubbed
+ execute_class_method = wrapped_class_method
+
+ except ImportError:
+ log.warning(
+ "ProductionWrapper not available, cross-endpoint routing disabled"
+ )
+
# Attach the method to the function
stubbed_resource.execute_class_method = execute_class_method
diff --git a/tests/conftest.py b/tests/conftest.py
index 120b2e4f..af5f4b6b 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -9,12 +9,17 @@
- Logger suppression
"""
+import gc
+import threading
from pathlib import Path
from typing import Any, Dict
from unittest.mock import AsyncMock, Mock
import pytest
+from tetra_rp.core.resources.resource_manager import ResourceManager
+from tetra_rp.core.utils.singleton import SingletonMixin
+
@pytest.fixture
def sample_gpu_config() -> Dict[str, Any]:
@@ -172,13 +177,6 @@ def reset_singletons():
This fixture runs automatically for all tests to ensure
clean state between test executions.
"""
- import gc
- import threading
-
- # Import here to avoid circular dependencies
- from tetra_rp.core.resources.resource_manager import ResourceManager
- from tetra_rp.core.utils.singleton import SingletonMixin
-
# Patch cloudpickle to handle threading.Lock objects that may be left over
# from previous tests. This prevents "cannot pickle '_thread.lock'" errors
# when test pollution causes old lock instances to be in the object graph.
diff --git a/tests/integration/test_class_execution_integration.py b/tests/integration/test_class_execution_integration.py
index 39111734..374e32af 100644
--- a/tests/integration/test_class_execution_integration.py
+++ b/tests/integration/test_class_execution_integration.py
@@ -640,11 +640,15 @@ async def mock_ensure_initialized():
):
# The error should occur during method call when trying to serialize
# Mock cloudpickle.dumps to raise an error
+ from tetra_rp.runtime.exceptions import SerializationError
+
with patch(
- "tetra_rp.execute_class.cloudpickle.dumps",
+ "tetra_rp.runtime.serialization.cloudpickle.dumps",
side_effect=TypeError("Can't pickle file objects"),
):
- with pytest.raises(TypeError, match="Can't pickle file objects"):
+ with pytest.raises(
+ SerializationError, match="Can't pickle file objects"
+ ):
await instance.process_file()
@pytest.mark.asyncio
diff --git a/tests/integration/test_cross_endpoint_routing.py b/tests/integration/test_cross_endpoint_routing.py
new file mode 100644
index 00000000..1b67967e
--- /dev/null
+++ b/tests/integration/test_cross_endpoint_routing.py
@@ -0,0 +1,326 @@
+"""Integration tests for cross-endpoint routing."""
+
+import json
+import tempfile
+from pathlib import Path
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+
+from tetra_rp.runtime.manifest_client import ManifestClient
+from tetra_rp.runtime.production_wrapper import (
+ ProductionWrapper,
+ create_production_wrapper,
+ reset_wrapper,
+)
+from tetra_rp.runtime.service_registry import ServiceRegistry
+
+
+class TestCrossEndpointRoutingIntegration:
+ """Integration tests for full cross-endpoint routing flow."""
+
+ @pytest.fixture(autouse=True)
+ def cleanup(self):
+ """Clean up wrapper singleton between tests."""
+ yield
+ reset_wrapper()
+
+ @pytest.fixture
+ def manifest(self):
+ """Sample manifest with multiple endpoints."""
+ return {
+ "version": "1.0",
+ "project_name": "integration_test",
+ "function_registry": {
+ "gpu_task": "gpu_config",
+ "cpu_task": "cpu_config",
+ "preprocess": "cpu_config",
+ },
+ "resources": {
+ "gpu_config": {
+ "resource_type": "LiveServerless",
+ "handler_file": "handler_gpu.py",
+ "functions": [
+ {"name": "gpu_task", "module": "workers.gpu", "is_async": True}
+ ],
+ },
+ "cpu_config": {
+ "resource_type": "LiveServerless",
+ "handler_file": "handler_cpu.py",
+ "functions": [
+ {
+ "name": "cpu_task",
+ "module": "workers.cpu",
+ "is_async": False,
+ },
+ {
+ "name": "preprocess",
+ "module": "workers.cpu",
+ "is_async": False,
+ },
+ ],
+ },
+ },
+ }
+
+ @pytest.mark.asyncio
+ async def test_local_function_execution(self, manifest):
+ """Test that local function executes without remote call."""
+ # Current endpoint is GPU
+ with patch.dict(
+ "os.environ",
+ {
+ "RUNPOD_ENDPOINT_ID": "gpu_config",
+ "FLASH_MOTHERSHIP_URL": "https://mothership.example.com",
+ },
+ ):
+ directory = {
+ "gpu_config": "https://gpu.example.com",
+ "cpu_config": "https://cpu.example.com",
+ }
+
+ with tempfile.NamedTemporaryFile(
+ mode="w", suffix=".json", delete=False
+ ) as f:
+ json.dump(manifest, f)
+ manifest_path = Path(f.name)
+
+ try:
+ registry = ServiceRegistry(manifest_path=manifest_path)
+
+ mock_dir_client = AsyncMock(spec=ManifestClient)
+ mock_dir_client.get_directory.return_value = directory
+
+ registry._directory_client = mock_dir_client
+ registry._directory = directory
+ registry._directory_loaded_at = float("inf")
+
+ wrapper = ProductionWrapper(registry)
+
+ async def gpu_task(x):
+ return x * 2
+
+ original_stub = AsyncMock()
+ original_stub.return_value = 42
+
+ await wrapper.wrap_function_execution(
+ original_stub,
+ gpu_task,
+ None,
+ None,
+ True,
+ 5,
+ )
+
+ original_stub.assert_called_once()
+
+ finally:
+ manifest_path.unlink()
+
+ @pytest.mark.asyncio
+ async def test_remote_function_execution_routing(self, manifest):
+ """Test that remote function is routed via ServerlessResource."""
+ # Current endpoint is GPU, calling CPU function
+ with patch.dict(
+ "os.environ",
+ {
+ "RUNPOD_ENDPOINT_ID": "gpu_config",
+ "FLASH_MOTHERSHIP_URL": "https://mothership.example.com",
+ },
+ ):
+ directory = {
+ "gpu_config": "https://gpu.example.com",
+ "cpu_config": "https://cpu.example.com",
+ }
+
+ with tempfile.NamedTemporaryFile(
+ mode="w", suffix=".json", delete=False
+ ) as f:
+ json.dump(manifest, f)
+ manifest_path = Path(f.name)
+
+ try:
+ registry = ServiceRegistry(manifest_path=manifest_path)
+ mock_dir_client = AsyncMock(spec=ManifestClient)
+ mock_dir_client.get_directory.return_value = directory
+ registry._directory_client = mock_dir_client
+ registry._directory = directory
+ registry._directory_loaded_at = float("inf")
+
+ # Mock ServerlessResource
+ mock_resource = AsyncMock()
+ mock_resource.run_sync = AsyncMock()
+ mock_resource.run_sync.return_value = MagicMock(
+ error="", output="processed"
+ )
+
+ wrapper = ProductionWrapper(registry)
+
+ # Mock get_resource_for_function to return our mock resource
+ with patch.object(
+ registry, "get_resource_for_function", return_value=mock_resource
+ ):
+
+ async def cpu_task(x):
+ return x * 3
+
+ original_stub = AsyncMock()
+
+ result = await wrapper.wrap_function_execution(
+ original_stub,
+ cpu_task,
+ None,
+ None,
+ True,
+ 10,
+ )
+
+ original_stub.assert_not_called()
+ mock_resource.run_sync.assert_called_once()
+ assert result == "processed"
+
+ finally:
+ manifest_path.unlink()
+
+ @pytest.mark.asyncio
+ async def test_directory_loading_on_demand(self, manifest):
+ """Test that directory is loaded on-demand before routing decision."""
+ with patch.dict(
+ "os.environ",
+ {
+ "RUNPOD_ENDPOINT_ID": "gpu_config",
+ "FLASH_MOTHERSHIP_URL": "https://mothership.example.com",
+ },
+ ):
+ directory = {
+ "gpu_config": "https://gpu.example.com",
+ "cpu_config": "https://cpu.example.com",
+ }
+
+ with tempfile.NamedTemporaryFile(
+ mode="w", suffix=".json", delete=False
+ ) as f:
+ json.dump(manifest, f)
+ manifest_path = Path(f.name)
+
+ try:
+ registry = ServiceRegistry(manifest_path=manifest_path)
+ mock_dir_client = AsyncMock(spec=ManifestClient)
+ mock_dir_client.get_directory.return_value = directory
+ registry._directory_client = mock_dir_client
+
+ assert registry._directory == {}
+
+ wrapper = ProductionWrapper(registry)
+
+ async def cpu_task(x):
+ return x
+
+ original_stub = AsyncMock()
+
+ # Mock get_resource_for_function to return a mock resource
+ mock_resource = AsyncMock()
+ mock_resource.run_sync = AsyncMock()
+ mock_resource.run_sync.return_value = MagicMock(error="", output=None)
+
+ with patch.object(
+ registry, "get_resource_for_function", return_value=mock_resource
+ ):
+ await wrapper.wrap_function_execution(
+ original_stub, cpu_task, None, None, True
+ )
+
+ assert len(registry._directory) > 0
+ assert registry._directory["gpu_config"] == "https://gpu.example.com"
+
+ finally:
+ manifest_path.unlink()
+
+ @pytest.mark.asyncio
+ async def test_error_handling_in_remote_execution(self, manifest):
+ """Test that errors from remote execution are properly propagated."""
+ with patch.dict(
+ "os.environ",
+ {
+ "RUNPOD_ENDPOINT_ID": "gpu_config",
+ "FLASH_MOTHERSHIP_URL": "https://mothership.example.com",
+ },
+ ):
+ directory = {
+ "gpu_config": "https://gpu.example.com",
+ "cpu_config": "https://cpu.example.com",
+ }
+
+ with tempfile.NamedTemporaryFile(
+ mode="w", suffix=".json", delete=False
+ ) as f:
+ json.dump(manifest, f)
+ manifest_path = Path(f.name)
+
+ try:
+ registry = ServiceRegistry(manifest_path=manifest_path)
+ mock_dir_client = AsyncMock(spec=ManifestClient)
+ mock_dir_client.get_directory.return_value = directory
+ registry._directory_client = mock_dir_client
+ registry._directory = directory
+ registry._directory_loaded_at = float("inf")
+
+ # Mock ServerlessResource that returns error
+ mock_resource = AsyncMock()
+ mock_resource.run_sync = AsyncMock()
+ mock_resource.run_sync.return_value = MagicMock(
+ success=False, error="Remote function failed: ValueError"
+ )
+
+ wrapper = ProductionWrapper(registry)
+
+ with patch.object(
+ registry, "get_resource_for_function", return_value=mock_resource
+ ):
+
+ async def cpu_task():
+ pass
+
+ original_stub = AsyncMock()
+
+ with pytest.raises(Exception, match="Remote execution.*failed"):
+ await wrapper.wrap_function_execution(
+ original_stub, cpu_task, None, None, True
+ )
+
+ finally:
+ manifest_path.unlink()
+
+ def test_factory_creates_complete_system(self):
+ """Test that factory creates fully integrated system."""
+ manifest = {
+ "version": "1.0",
+ "project_name": "test",
+ "function_registry": {"task": "resource1"},
+ "resources": {
+ "resource1": {
+ "functions": [{"name": "task", "module": "m", "is_async": True}]
+ }
+ },
+ }
+
+ with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
+ json.dump(manifest, f)
+ manifest_path = Path(f.name)
+
+ try:
+ with patch.dict(
+ "os.environ",
+ {
+ "RUNPOD_ENDPOINT_ID": "resource1",
+ "FLASH_MOTHERSHIP_URL": "https://mothership.example.com",
+ },
+ ):
+ wrapper = create_production_wrapper()
+
+ assert wrapper.service_registry is not None
+ assert isinstance(wrapper.service_registry, ServiceRegistry)
+
+ finally:
+ manifest_path.unlink()
+ reset_wrapper()
diff --git a/tests/unit/runtime/test_manifest_client.py b/tests/unit/runtime/test_manifest_client.py
new file mode 100644
index 00000000..27bb12cc
--- /dev/null
+++ b/tests/unit/runtime/test_manifest_client.py
@@ -0,0 +1,186 @@
+"""Tests for ManifestClient."""
+
+import asyncio
+import os
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+
+from tetra_rp.runtime.manifest_client import (
+ ManifestClient,
+ ManifestServiceUnavailableError,
+)
+
+
+class TestManifestClient:
+ """Test ManifestClient functionality."""
+
+ @pytest.fixture
+ def mock_response(self):
+ """Mock successful HTTP response."""
+ response = MagicMock()
+ response.status_code = 200
+ response.json.return_value = {
+ "directory": {
+ "gpu_config": "https://api.runpod.io/v2/gpu123",
+ "cpu_config": "https://api.runpod.io/v2/cpu456",
+ },
+ "updated_at": "2025-01-03T12:00:00Z",
+ }
+ return response
+
+ def test_init_with_url(self):
+ """Test initialization with explicit URL."""
+ client = ManifestClient(mothership_url="https://mothership.example.com")
+ assert client.mothership_url == "https://mothership.example.com"
+
+ def test_init_from_env(self):
+ """Test initialization from environment variable."""
+ with patch.dict(os.environ, {"FLASH_MOTHERSHIP_URL": "https://from-env.com"}):
+ client = ManifestClient()
+ assert client.mothership_url == "https://from-env.com"
+
+ def test_init_missing_url(self):
+ """Test initialization fails without URL."""
+ with patch.dict(os.environ, {}, clear=True):
+ with pytest.raises(ValueError, match="mothership_url required"):
+ ManifestClient()
+
+ def test_init_explicit_over_env(self):
+ """Test explicit URL takes precedence over env var."""
+ with patch.dict(os.environ, {"FLASH_MOTHERSHIP_URL": "https://env.com"}):
+ client = ManifestClient(mothership_url="https://explicit.com")
+ assert client.mothership_url == "https://explicit.com"
+
+ @pytest.mark.asyncio
+ async def test_get_directory_success(self, mock_response):
+ """Test successful directory fetch."""
+ client = ManifestClient(mothership_url="https://mothership.example.com")
+
+ with patch("tetra_rp.runtime.manifest_client.httpx"):
+ mock_client = AsyncMock()
+ mock_client.__aenter__.return_value = mock_client
+ mock_client.get.return_value = mock_response
+
+ with patch.object(client, "_get_client", return_value=mock_client):
+ directory = await client.get_directory()
+
+ assert directory == {
+ "gpu_config": "https://api.runpod.io/v2/gpu123",
+ "cpu_config": "https://api.runpod.io/v2/cpu456",
+ }
+
+ @pytest.mark.asyncio
+ async def test_get_directory_http_error(self):
+ """Test handling of HTTP errors."""
+ client = ManifestClient(mothership_url="https://mothership.example.com")
+
+ response = MagicMock()
+ response.status_code = 500
+ response.text = "Internal server error"
+
+ with patch.object(client, "_get_client") as mock_get_client:
+ mock_http_client = AsyncMock()
+ mock_http_client.get.return_value = response
+ mock_http_client.is_closed = False
+ mock_get_client.return_value = mock_http_client
+
+ with pytest.raises(ManifestServiceUnavailableError, match="500"):
+ await client.get_directory()
+
+ @pytest.mark.asyncio
+ async def test_get_directory_timeout(self):
+ """Test handling of request timeout."""
+ client = ManifestClient(
+ mothership_url="https://mothership.example.com", timeout=0.1
+ )
+
+ with patch.object(client, "_get_client") as mock_get_client:
+ mock_http_client = AsyncMock()
+ mock_http_client.get.side_effect = asyncio.TimeoutError("Timed out")
+ mock_http_client.is_closed = False
+ mock_get_client.return_value = mock_http_client
+
+ with pytest.raises(
+ ManifestServiceUnavailableError, match="after \\d+ attempts"
+ ):
+ await client.get_directory()
+
+ @pytest.mark.asyncio
+ async def test_get_directory_retry(self):
+ """Test retry logic on transient failure."""
+ client = ManifestClient(
+ mothership_url="https://mothership.example.com", max_retries=3
+ )
+
+ response = MagicMock()
+ response.status_code = 200
+ response.json.return_value = {"directory": {"gpu": "https://gpu.example.com"}}
+
+ with patch.object(client, "_get_client") as mock_get_client:
+ mock_http_client = AsyncMock()
+
+ # First two calls fail, third succeeds
+ mock_http_client.get.side_effect = [
+ Exception("Network error"),
+ Exception("Network error"),
+ response,
+ ]
+ mock_http_client.is_closed = False
+ mock_get_client.return_value = mock_http_client
+
+ with patch(
+ "tetra_rp.runtime.manifest_client.asyncio.sleep",
+ new_callable=AsyncMock,
+ ):
+ directory = await client.get_directory()
+ assert directory == {"gpu": "https://gpu.example.com"}
+ assert mock_http_client.get.call_count == 3
+
+ @pytest.mark.asyncio
+ async def test_get_directory_exhaust_retries(self):
+ """Test failure after exhausting retries."""
+ client = ManifestClient(
+ mothership_url="https://mothership.example.com", max_retries=2
+ )
+
+ with patch.object(client, "_get_client") as mock_get_client:
+ mock_http_client = AsyncMock()
+ mock_http_client.get.side_effect = Exception("Always fails")
+ mock_http_client.is_closed = False
+ mock_get_client.return_value = mock_http_client
+
+ with patch(
+ "tetra_rp.runtime.manifest_client.asyncio.sleep",
+ new_callable=AsyncMock,
+ ):
+ with pytest.raises(
+ ManifestServiceUnavailableError, match="after 2 attempts"
+ ):
+ await client.get_directory()
+
+ @pytest.mark.asyncio
+ async def test_context_manager(self):
+ """Test async context manager."""
+ client = ManifestClient(mothership_url="https://mothership.example.com")
+
+ with patch.object(client, "close", new_callable=AsyncMock) as mock_close:
+ async with client:
+ pass
+
+ mock_close.assert_called_once()
+
+ @pytest.mark.asyncio
+ async def test_close(self):
+ """Test proper cleanup of HTTP client."""
+ client = ManifestClient(mothership_url="https://mothership.example.com")
+
+ with patch("tetra_rp.runtime.manifest_client.httpx"):
+ mock_http_client = AsyncMock()
+ mock_http_client.is_closed = False
+
+ with patch.object(client, "_get_client", return_value=mock_http_client):
+ client._client = mock_http_client
+ await client.close()
+
+ mock_http_client.aclose.assert_called_once()
diff --git a/tests/unit/runtime/test_production_wrapper.py b/tests/unit/runtime/test_production_wrapper.py
new file mode 100644
index 00000000..cc628047
--- /dev/null
+++ b/tests/unit/runtime/test_production_wrapper.py
@@ -0,0 +1,302 @@
+"""Tests for ProductionWrapper."""
+
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+
+from tetra_rp.runtime.production_wrapper import (
+ ProductionWrapper,
+ create_production_wrapper,
+ reset_wrapper,
+)
+from tetra_rp.runtime.service_registry import ServiceRegistry
+
+
+class TestProductionWrapper:
+ """Test ProductionWrapper routing logic."""
+
+ @pytest.fixture
+ def mock_registry(self):
+ """Mock service registry."""
+ registry = AsyncMock(spec=ServiceRegistry)
+ registry._ensure_directory_loaded = AsyncMock()
+ return registry
+
+ @pytest.fixture
+ def wrapper(self, mock_registry):
+ """Create wrapper with mocked dependencies."""
+ return ProductionWrapper(mock_registry)
+
+ @pytest.fixture
+ def sample_function(self):
+ """Sample function for testing."""
+
+ async def test_func(x, y):
+ return x + y
+
+ return test_func
+
+ @pytest.fixture
+ def original_stub(self):
+ """Mock original stub function."""
+ return AsyncMock()
+
+ @pytest.mark.asyncio
+ async def test_wrap_function_local_execution(
+ self, wrapper, mock_registry, original_stub, sample_function
+ ):
+ """Test routing local function to original stub."""
+ mock_registry.get_resource_for_function.return_value = None
+
+ await wrapper.wrap_function_execution(
+ original_stub,
+ sample_function,
+ None, # dependencies
+ None, # system_dependencies
+ True, # accelerate_downloads
+ 1,
+ 2,
+ key="value",
+ )
+
+ # Should call original stub
+ original_stub.assert_called_once()
+ call_args = original_stub.call_args
+ assert call_args[0][0] == sample_function
+        assert call_args[0][4] == 1  # First user-supplied positional arg
+
+ @pytest.mark.asyncio
+ async def test_wrap_function_remote_execution(
+ self, wrapper, mock_registry, original_stub, sample_function
+ ):
+ """Test routing remote function via ServerlessResource."""
+ mock_resource = AsyncMock()
+ mock_resource.run_sync = AsyncMock()
+ mock_resource.run_sync.return_value = MagicMock(error="", output=42)
+
+ mock_registry.get_resource_for_function.return_value = mock_resource
+
+ result = await wrapper.wrap_function_execution(
+ original_stub,
+ sample_function,
+ None, # dependencies
+ None, # system_dependencies
+ True, # accelerate_downloads
+ 1,
+ 2,
+ )
+
+ assert result == 42
+ # Should NOT call original stub
+ original_stub.assert_not_called()
+ # Should call ServerlessResource.run_sync()
+ mock_resource.run_sync.assert_called_once()
+
+ @pytest.mark.asyncio
+ async def test_wrap_function_not_in_manifest(
+ self, wrapper, mock_registry, original_stub, sample_function
+ ):
+ """Test function not found in manifest executes locally."""
+ mock_registry.get_resource_for_function.side_effect = ValueError(
+ "Function not found"
+ )
+
+ await wrapper.wrap_function_execution(
+ original_stub,
+ sample_function,
+ None, # dependencies
+ None, # system_dependencies
+ True, # accelerate_downloads
+ 1,
+ 2,
+ )
+
+ # Should call original stub
+ original_stub.assert_called_once()
+
+ @pytest.mark.asyncio
+ async def test_wrap_function_remote_error(
+ self, wrapper, mock_registry, original_stub, sample_function
+ ):
+ """Test error handling for failed remote execution."""
+ mock_resource = AsyncMock()
+ mock_resource.run_sync = AsyncMock()
+ mock_resource.run_sync.return_value = MagicMock(error="Remote execution failed")
+
+ mock_registry.get_resource_for_function.return_value = mock_resource
+
+ with pytest.raises(Exception, match="Remote execution failed"):
+ await wrapper.wrap_function_execution(
+ original_stub,
+ sample_function,
+ dependencies=None,
+ system_dependencies=None,
+ accelerate_downloads=True,
+ )
+
+ @pytest.mark.asyncio
+ async def test_wrap_function_loads_directory(self, wrapper, mock_registry):
+ """Test that directory is loaded before routing decision."""
+ mock_registry.get_resource_for_function.return_value = None
+
+ async def sample_func():
+ pass
+
+ original_stub = AsyncMock()
+ await wrapper.wrap_function_execution(
+ original_stub, sample_func, None, None, True
+ )
+
+ # Should ensure directory is loaded
+ mock_registry._ensure_directory_loaded.assert_called_once()
+
+ @pytest.mark.asyncio
+ async def test_wrap_class_method_local(self, wrapper, mock_registry, original_stub):
+ """Test routing local class method."""
+ request = MagicMock()
+ request.class_name = "MyClass"
+
+ mock_registry.get_resource_for_function.return_value = None
+
+ await wrapper.wrap_class_method_execution(original_stub, request)
+
+ # Should call original
+ original_stub.assert_called_once_with(request)
+
+ @pytest.mark.asyncio
+ async def test_wrap_class_method_remote(
+ self, wrapper, mock_registry, original_stub
+ ):
+ """Test routing remote class method."""
+ request = MagicMock()
+ request.class_name = "MyClass"
+ request.method_name = "process"
+ request.model_dump = MagicMock(
+ return_value={
+ "class_name": "MyClass",
+ "method_name": "process",
+ "args": [],
+ "kwargs": {},
+ }
+ )
+
+ mock_resource = AsyncMock()
+ mock_resource.run_sync = AsyncMock()
+ mock_resource.run_sync.return_value = MagicMock(error="", output="done")
+
+ mock_registry.get_resource_for_function.return_value = mock_resource
+
+ result = await wrapper.wrap_class_method_execution(original_stub, request)
+
+ assert result == "done"
+ original_stub.assert_not_called()
+ mock_resource.run_sync.assert_called_once()
+
+ @pytest.mark.asyncio
+ async def test_wrap_class_method_no_class_name(self, wrapper, original_stub):
+ """Test class method with no class_name executes locally."""
+ request = MagicMock()
+ request.class_name = None
+
+ await wrapper.wrap_class_method_execution(original_stub, request)
+
+ original_stub.assert_called_once_with(request)
+
+ @pytest.mark.asyncio
+ async def test_execute_remote_payload_format(self, wrapper, sample_function):
+ """Test that remote payload matches RunPod format."""
+ mock_resource = AsyncMock()
+ mock_resource.run_sync = AsyncMock()
+ mock_resource.run_sync.return_value = MagicMock(error="", output=None)
+
+ with patch("tetra_rp.runtime.serialization.cloudpickle") as mock_pickle:
+ mock_pickle.dumps.return_value = b"pickled"
+
+ await wrapper._execute_remote(
+ mock_resource,
+ "gpu_task",
+ (1, 2),
+ {"key": "value"},
+ execution_type="function",
+ )
+
+ call_args = mock_resource.run_sync.call_args
+ payload = call_args[0][0]
+
+ assert payload["input"]["function_name"] == "gpu_task"
+ assert payload["input"]["execution_type"] == "function"
+ assert len(payload["input"]["args"]) == 2
+ assert "key" in payload["input"]["kwargs"]
+
+ @pytest.mark.asyncio
+ async def test_build_class_payload_dict_request(self, wrapper):
+ """Test building class payload from dict request."""
+ request = {
+ "class_name": "MyClass",
+ "method_name": "process",
+ "args": ["arg1"],
+ "kwargs": {"key": "value"},
+ }
+
+ payload = wrapper._build_class_payload(request)
+
+ assert payload["input"]["function_name"] == "MyClass"
+ assert payload["input"]["execution_type"] == "class"
+ assert payload["input"]["method_name"] == "process"
+
+ @pytest.mark.asyncio
+ async def test_build_class_payload_object_request(self, wrapper):
+ """Test building class payload from object request."""
+ request = MagicMock()
+ request.model_dump.return_value = {
+ "class_name": "MyClass",
+ "method_name": "process",
+ "args": ["arg1"],
+ "kwargs": {"key": "value"},
+ }
+
+ payload = wrapper._build_class_payload(request)
+
+ assert payload["input"]["function_name"] == "MyClass"
+ assert payload["input"]["execution_type"] == "class"
+
+
+class TestCreateProductionWrapper:
+ """Test ProductionWrapper factory function."""
+
+ def teardown_method(self):
+ """Reset wrapper after each test."""
+ reset_wrapper()
+
+ def test_create_wrapper_singleton(self):
+ """Test that create_production_wrapper returns singleton."""
+ wrapper1 = create_production_wrapper()
+ wrapper2 = create_production_wrapper()
+
+ assert wrapper1 is wrapper2
+
+ def test_create_wrapper_with_custom_registry(self):
+ """Test creating wrapper with custom registry."""
+ registry = AsyncMock(spec=ServiceRegistry)
+
+ wrapper = create_production_wrapper(registry)
+
+ assert wrapper.service_registry is registry
+
+ def test_create_wrapper_creates_defaults(self):
+ """Test that wrapper creates default components."""
+ with patch(
+ "tetra_rp.runtime.production_wrapper.ServiceRegistry"
+ ) as mock_registry_class:
+ create_production_wrapper()
+
+ # Should have created ServiceRegistry instance
+ assert mock_registry_class.called
+
+ def test_reset_wrapper(self):
+ """Test resetting wrapper singleton."""
+ wrapper1 = create_production_wrapper()
+ reset_wrapper()
+ wrapper2 = create_production_wrapper()
+
+ assert wrapper1 is not wrapper2
diff --git a/tests/unit/runtime/test_service_registry.py b/tests/unit/runtime/test_service_registry.py
new file mode 100644
index 00000000..8dc88aa1
--- /dev/null
+++ b/tests/unit/runtime/test_service_registry.py
@@ -0,0 +1,297 @@
+"""Tests for ServiceRegistry."""
+
+import json
+import os
+import tempfile
+import time
+from pathlib import Path
+from unittest.mock import AsyncMock, patch
+
+import pytest
+
+from tetra_rp.runtime.service_registry import ServiceRegistry
+
+
+class TestServiceRegistry:
+ """Test ServiceRegistry functionality."""
+
+ @pytest.fixture
+ def manifest_dict(self):
+ """Sample manifest."""
+ return {
+ "version": "1.0",
+ "project_name": "test_app",
+ "function_registry": {
+ "gpu_task": "gpu_config",
+ "preprocess": "cpu_config",
+ "inference": "gpu_config",
+ },
+ "resources": {
+ "gpu_config": {
+ "resource_type": "LiveServerless",
+ "handler_file": "handler_gpu_config.py",
+ "functions": [
+ {"name": "gpu_task", "module": "workers.gpu", "is_async": True},
+ {
+ "name": "inference",
+ "module": "workers.gpu",
+ "is_async": True,
+ },
+ ],
+ },
+ "cpu_config": {
+ "resource_type": "LiveServerless",
+ "handler_file": "handler_cpu_config.py",
+ "functions": [
+ {
+ "name": "preprocess",
+ "module": "workers.cpu",
+ "is_async": False,
+ },
+ ],
+ },
+ },
+ }
+
+ @pytest.fixture
+ def manifest_file(self, manifest_dict):
+ """Create temporary manifest file."""
+ with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
+ json.dump(manifest_dict, f)
+ path = f.name
+
+ yield Path(path)
+
+ # Cleanup
+ Path(path).unlink()
+
+ def test_init_with_manifest_path(self, manifest_file):
+ """Test initialization with explicit manifest path."""
+ registry = ServiceRegistry(manifest_path=manifest_file)
+ assert registry.get_manifest().project_name == "test_app"
+
+ def test_init_from_env_manifest_path(self, manifest_file):
+ """Test initialization from FLASH_MANIFEST_PATH env var."""
+ with patch.dict(os.environ, {"FLASH_MANIFEST_PATH": str(manifest_file)}):
+ registry = ServiceRegistry()
+ assert registry.get_manifest().project_name == "test_app"
+
+ def test_init_manifest_not_found(self):
+ """Test initialization with missing manifest."""
+ with patch.dict(os.environ, {}, clear=True):
+ registry = ServiceRegistry(manifest_path=Path("/nonexistent/manifest.json"))
+ # Should not fail, returns empty manifest
+ assert registry.get_manifest().function_registry == {}
+
+ def test_get_current_endpoint_id(self):
+ """Test retrieval of current endpoint ID from env."""
+ with patch.dict(os.environ, {"RUNPOD_ENDPOINT_ID": "gpu-endpoint-123"}):
+ registry = ServiceRegistry(manifest_path=Path("/nonexistent"))
+ assert registry.get_current_endpoint_id() == "gpu-endpoint-123"
+
+ def test_get_current_endpoint_id_not_set(self):
+ """Test when endpoint ID not set."""
+ with patch.dict(os.environ, {}, clear=True):
+ registry = ServiceRegistry(manifest_path=Path("/nonexistent"))
+ assert registry.get_current_endpoint_id() is None
+
+ def test_is_local_function_local(self, manifest_file):
+ """Test determining local function."""
+ with patch.dict(os.environ, {"RUNPOD_ENDPOINT_ID": "gpu_config"}):
+ registry = ServiceRegistry(manifest_path=manifest_file)
+ assert registry.is_local_function("gpu_task") is True
+ assert registry.is_local_function("inference") is True
+
+ def test_is_local_function_remote(self, manifest_file):
+ """Test determining remote function (with directory loaded)."""
+ with patch.dict(os.environ, {"RUNPOD_ENDPOINT_ID": "gpu_config"}):
+ mock_client = AsyncMock()
+ mock_client.get_directory.return_value = {
+ "cpu_config": "https://cpu.example.com"
+ }
+
+ registry = ServiceRegistry(
+ manifest_path=manifest_file, directory_client=mock_client
+ )
+            # The directory has not been fetched yet: is_local_function is
+            # synchronous and cannot trigger the async directory load, so the
+            # remote function is still reported as local. This is expected.
+ assert registry.is_local_function("preprocess") is True
+
+ def test_is_local_function_not_in_manifest(self, manifest_file):
+ """Test function not in manifest."""
+ registry = ServiceRegistry(manifest_path=manifest_file)
+ # Unknown function assumed local
+ assert registry.is_local_function("unknown_function") is True
+
+ def test_get_endpoint_for_function_local(self, manifest_file):
+ """Test getting endpoint for local function."""
+ with patch.dict(os.environ, {"RUNPOD_ENDPOINT_ID": "gpu_config"}):
+ registry = ServiceRegistry(manifest_path=manifest_file)
+ endpoint = registry.get_endpoint_for_function("gpu_task")
+ assert endpoint is None # Local returns None
+
+ def test_get_endpoint_for_function_remote_no_directory(self, manifest_file):
+ """Test getting endpoint for remote function without directory."""
+ with patch.dict(os.environ, {"RUNPOD_ENDPOINT_ID": "gpu_config"}):
+ registry = ServiceRegistry(manifest_path=manifest_file)
+ # CPU function is remote, but no directory loaded
+ endpoint = registry.get_endpoint_for_function("preprocess")
+ assert endpoint is None
+
+ def test_get_endpoint_for_function_not_in_manifest(self, manifest_file):
+ """Test getting endpoint for unknown function."""
+ registry = ServiceRegistry(manifest_path=manifest_file)
+ with pytest.raises(ValueError, match="not found in manifest"):
+ registry.get_endpoint_for_function("unknown_function")
+
+ def test_get_resource_for_function_local(self, manifest_file):
+ """Test getting ServerlessResource for local function."""
+ with patch.dict(os.environ, {"RUNPOD_ENDPOINT_ID": "gpu_config"}):
+ registry = ServiceRegistry(manifest_path=manifest_file)
+ resource = registry.get_resource_for_function("gpu_task")
+ # Local function returns None
+ assert resource is None
+
+ def test_get_resource_for_function_remote(self, manifest_file):
+ """Test getting ServerlessResource for remote function."""
+ with patch.dict(os.environ, {"RUNPOD_ENDPOINT_ID": "gpu_config"}):
+ mock_client = AsyncMock()
+ mock_client.get_directory.return_value = {
+ "cpu_config": "https://api.runpod.io/v2/abc123"
+ }
+
+ registry = ServiceRegistry(
+ manifest_path=manifest_file, directory_client=mock_client
+ )
+ # Manually set directory to simulate loaded state
+ registry._directory = {"cpu_config": "https://api.runpod.io/v2/abc123"}
+
+ resource = registry.get_resource_for_function("preprocess")
+
+ # Should return ServerlessResource
+ assert resource is not None
+ assert resource.id == "abc123"
+ # Name starts with remote_preprocess (may have random suffix appended)
+ assert resource.name.startswith("remote_preprocess")
+
+ def test_get_resource_for_function_not_in_manifest(self, manifest_file):
+ """Test getting resource for unknown function."""
+ registry = ServiceRegistry(manifest_path=manifest_file)
+ with pytest.raises(ValueError, match="not found in manifest"):
+ registry.get_resource_for_function("unknown_function")
+
+ @pytest.mark.asyncio
+ async def test_ensure_directory_loaded(self, manifest_file):
+ """Test lazy loading of directory from client."""
+ mock_directory = {
+ "gpu_config": "https://gpu.example.com",
+ "cpu_config": "https://cpu.example.com",
+ }
+
+ mock_client = AsyncMock()
+ mock_client.get_directory.return_value = mock_directory
+
+ registry = ServiceRegistry(
+ manifest_path=manifest_file, directory_client=mock_client, cache_ttl=10
+ )
+
+ # Directory not loaded yet
+ assert registry._directory == {}
+
+ # Load directory
+ await registry._ensure_directory_loaded()
+
+ # Should now have loaded directory
+ assert registry._directory == mock_directory
+ mock_client.get_directory.assert_called_once()
+
+ @pytest.mark.asyncio
+ async def test_ensure_directory_cache_respects_ttl(self, manifest_file):
+ """Test that directory cache respects TTL."""
+ mock_directory = {"gpu_config": "https://gpu.example.com"}
+
+ mock_client = AsyncMock()
+ mock_client.get_directory.return_value = mock_directory
+
+ registry = ServiceRegistry(
+ manifest_path=manifest_file, directory_client=mock_client, cache_ttl=1
+ )
+
+ # Load directory
+ await registry._ensure_directory_loaded()
+ assert mock_client.get_directory.call_count == 1
+
+ # Immediate reload should use cache
+ await registry._ensure_directory_loaded()
+ assert mock_client.get_directory.call_count == 1
+
+ # After TTL, should reload
+ registry._directory_loaded_at = time.time() - 2 # 2 seconds ago
+ await registry._ensure_directory_loaded()
+ assert mock_client.get_directory.call_count == 2
+
+ @pytest.mark.asyncio
+ async def test_refresh_directory(self, manifest_file):
+ """Test forcing directory refresh."""
+ mock_directory = {"gpu_config": "https://gpu.example.com"}
+
+ mock_client = AsyncMock()
+ mock_client.get_directory.return_value = mock_directory
+
+ registry = ServiceRegistry(
+ manifest_path=manifest_file, directory_client=mock_client, cache_ttl=3600
+ )
+
+ # Load directory
+ await registry._ensure_directory_loaded()
+ assert mock_client.get_directory.call_count == 1
+
+ # Force refresh
+ registry.refresh_directory()
+
+ # Next load should fetch again
+ await registry._ensure_directory_loaded()
+ assert mock_client.get_directory.call_count == 2
+
+ def test_get_manifest(self, manifest_file):
+ """Test getting manifest."""
+ registry = ServiceRegistry(manifest_path=manifest_file)
+ manifest = registry.get_manifest()
+ assert manifest.project_name == "test_app"
+
+ def test_get_all_resources(self, manifest_file):
+ """Test getting all resources."""
+ registry = ServiceRegistry(manifest_path=manifest_file)
+ resources = registry.get_all_resources()
+ assert "gpu_config" in resources
+ assert "cpu_config" in resources
+
+ def test_get_resource_functions(self, manifest_file):
+ """Test getting functions for a resource."""
+ registry = ServiceRegistry(manifest_path=manifest_file)
+ functions = registry.get_resource_functions("gpu_config")
+ assert len(functions) == 2
+ names = [f["name"] for f in functions]
+ assert "gpu_task" in names
+ assert "inference" in names
+
+ def test_get_resource_functions_not_found(self, manifest_file):
+ """Test getting functions for nonexistent resource."""
+ registry = ServiceRegistry(manifest_path=manifest_file)
+ functions = registry.get_resource_functions("nonexistent")
+ assert functions == []
+
+ def test_init_no_directory_client_no_mothership_url(self, manifest_file):
+ """Test initialization without directory client or URL."""
+ with patch.dict(os.environ, {}, clear=True):
+ registry = ServiceRegistry(manifest_path=manifest_file)
+ assert registry._directory_client is None
+
+ @pytest.mark.asyncio
+ async def test_ensure_directory_loaded_unavailable_client(self, manifest_file):
+ """Test directory loading when client is None."""
+ registry = ServiceRegistry(manifest_path=manifest_file, directory_client=None)
+ # Should not fail, just log warning
+ await registry._ensure_directory_loaded()
+ assert registry._directory == {}