Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ dependencies = [
"cloudpickle>=3.1.1",
"runpod",
"python-dotenv>=1.0.0",
"pydantic>=2.0.0",
]

[dependency-groups]
Expand Down
28 changes: 25 additions & 3 deletions src/tetra_rp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ def remote(
resource_config: ServerlessResource,
dependencies: Optional[List[str]] = None,
system_dependencies: Optional[List[str]] = None,
accelerate_downloads: bool = True,
hf_models_to_cache: Optional[List[str]] = None,
**extra,
):
"""
Expand All @@ -22,10 +24,17 @@ def remote(
This decorator allows a function to be executed in a remote serverless environment, with support for
dynamic resource provisioning and installation of required dependencies.

Args:
resource_config (ServerlessResource): Configuration object specifying the serverless resource
to be provisioned or used.
dependencies (List[str], optional): A list of pip package names to be installed in the remote
environment before executing the function. Defaults to None.
system_dependencies (List[str], optional): A list of system packages to be installed in the remote
environment before executing the function. Defaults to None.
accelerate_downloads (bool, optional): Enable download acceleration for dependencies and models.
Defaults to True.
hf_models_to_cache (List[str], optional): List of HuggingFace model IDs to pre-cache using
download acceleration. Defaults to None.
extra (dict, optional): Additional parameters for the execution of the resource. Defaults to an empty dict.

Returns:
Expand All @@ -37,7 +46,8 @@ def remote(
@remote(
resource_config=my_resource_config,
dependencies=["numpy", "pandas"],
sync=True # Optional, to run synchronously
accelerate_downloads=True,
hf_models_to_cache=["gpt2", "bert-base-uncased"]
)
async def my_function(data):
# Function logic here
Expand All @@ -49,7 +59,13 @@ def decorator(func_or_class):
if inspect.isclass(func_or_class):
# Handle class decoration
return create_remote_class(
func_or_class, resource_config, dependencies, system_dependencies, extra
func_or_class,
resource_config,
dependencies,
system_dependencies,
accelerate_downloads,
hf_models_to_cache,
extra,
)
else:
# Handle function decoration (unchanged)
Expand All @@ -62,7 +78,13 @@ async def wrapper(*args, **kwargs):

stub = stub_resource(remote_resource, **extra)
return await stub(
func_or_class, dependencies, system_dependencies, *args, **kwargs
func_or_class,
dependencies,
system_dependencies,
accelerate_downloads,
hf_models_to_cache,
*args,
**kwargs,
)

return wrapper
Expand Down
6 changes: 6 additions & 0 deletions src/tetra_rp/execute_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,8 @@ def create_remote_class(
resource_config: ServerlessResource,
dependencies: Optional[List[str]],
system_dependencies: Optional[List[str]],
accelerate_downloads: bool,
hf_models_to_cache: Optional[List[str]],
extra: dict,
):
"""
Expand All @@ -219,6 +221,8 @@ def __init__(self, *args, **kwargs):
self._resource_config = resource_config
self._dependencies = dependencies or []
self._system_dependencies = system_dependencies or []
self._accelerate_downloads = accelerate_downloads
self._hf_models_to_cache = hf_models_to_cache
self._extra = extra
self._constructor_args = args
self._constructor_kwargs = kwargs
Expand Down Expand Up @@ -302,6 +306,8 @@ async def method_proxy(*args, **kwargs):
constructor_kwargs=constructor_kwargs,
dependencies=self._dependencies,
system_dependencies=self._system_dependencies,
accelerate_downloads=self._accelerate_downloads,
hf_models_to_cache=self._hf_models_to_cache,
instance_id=self._instance_id,
create_new_instance=not hasattr(
self, "_stub"
Expand Down
4 changes: 4 additions & 0 deletions src/tetra_rp/protos/remote_execution.proto
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ message FunctionRequest {
string method_name = 12; // Name of the method to call on the class instance (default: "__call__")
optional string instance_id = 13; // Unique identifier for the class instance (for persistence)
bool create_new_instance = 14; // Whether to create a new instance or reuse existing one

// Download acceleration fields
optional bool accelerate_downloads = 19; // Enable download acceleration for dependencies and models (default: true)
repeated string hf_models_to_cache = 20; // List of HuggingFace model IDs to pre-cache using acceleration
}

// The response message containing the execution result or error
Expand Down
48 changes: 36 additions & 12 deletions src/tetra_rp/protos/remote_execution.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,22 @@
# TODO: generate using betterproto
"""Remote execution protocol definitions using Pydantic models.

This module defines the request/response protocol for remote function and class execution.
The models align with the protobuf schema for communication with remote workers.
"""

from abc import ABC, abstractmethod
from typing import Dict, List, Optional
from typing import Any, Dict, List, Optional

from pydantic import BaseModel, Field, model_validator


class FunctionRequest(BaseModel):
"""Request model for remote function or class execution.

Supports both function-based execution and class instantiation with method calls.
All serialized data (args, kwargs, etc.) are base64-encoded cloudpickle strings.
"""

# MADE OPTIONAL - can be None for class-only execution
function_name: Optional[str] = Field(
default=None,
Expand All @@ -15,19 +26,19 @@ class FunctionRequest(BaseModel):
default=None,
description="Source code of the function to execute",
)
args: List = Field(
args: List[str] = Field(
default_factory=list,
description="List of base64-encoded cloudpickle-serialized arguments",
)
kwargs: Dict = Field(
kwargs: Dict[str, str] = Field(
default_factory=dict,
description="Dictionary of base64-encoded cloudpickle-serialized keyword arguments",
)
dependencies: Optional[List] = Field(
dependencies: Optional[List[str]] = Field(
default=None,
description="Optional list of pip packages to install before executing the function",
)
system_dependencies: Optional[List] = Field(
system_dependencies: Optional[List[str]] = Field(
default=None,
description="Optional list of system dependencies to install before executing the function",
)
Expand All @@ -44,11 +55,11 @@ class FunctionRequest(BaseModel):
default=None,
description="Source code of the class to instantiate (for class execution)",
)
constructor_args: Optional[List] = Field(
constructor_args: List[str] = Field(
default_factory=list,
description="List of base64-encoded cloudpickle-serialized constructor arguments",
)
constructor_kwargs: Optional[Dict] = Field(
constructor_kwargs: Dict[str, str] = Field(
default_factory=dict,
description="Dictionary of base64-encoded cloudpickle-serialized constructor keyword arguments",
)
Expand All @@ -65,6 +76,16 @@ class FunctionRequest(BaseModel):
description="Whether to create a new instance or reuse existing one",
)

# Download acceleration fields
accelerate_downloads: bool = Field(
default=True,
description="Enable download acceleration for dependencies and models",
)
hf_models_to_cache: Optional[List[str]] = Field(
default=None,
description="List of HuggingFace model IDs to pre-cache using acceleration",
)

@model_validator(mode="after")
def validate_execution_requirements(self) -> "FunctionRequest":
"""Validate that required fields are provided based on execution_type"""
Expand Down Expand Up @@ -92,7 +113,12 @@ def validate_execution_requirements(self) -> "FunctionRequest":


class FunctionResponse(BaseModel):
# EXISTING FIELDS (unchanged)
"""Response model for remote function or class execution results.

Contains execution results, error information, and metadata about class instances
when applicable. The result field contains base64-encoded cloudpickle data.
"""

success: bool = Field(
description="Indicates if the function execution was successful",
)
Expand All @@ -108,12 +134,10 @@ class FunctionResponse(BaseModel):
default=None,
description="Captured standard output from the function execution",
)

# NEW FIELDS FOR CLASS SUPPORT
instance_id: Optional[str] = Field(
default=None, description="ID of the class instance that was used/created"
)
instance_info: Optional[Dict] = Field(
instance_info: Optional[Dict[str, Any]] = Field(
default=None,
description="Metadata about the class instance (creation time, call count, etc.)",
)
Expand Down
13 changes: 12 additions & 1 deletion src/tetra_rp/stubs/live_serverless.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,24 @@ class LiveServerlessStub(RemoteExecutorStub):
def __init__(self, server: LiveServerless):
self.server = server

def prepare_request(self, func, dependencies, system_dependencies, *args, **kwargs):
def prepare_request(
self,
func,
dependencies,
system_dependencies,
accelerate_downloads,
hf_models_to_cache,
*args,
**kwargs,
):
source, src_hash = get_function_source(func)

request = {
"function_name": func.__name__,
"dependencies": dependencies,
"system_dependencies": system_dependencies,
"accelerate_downloads": accelerate_downloads,
"hf_models_to_cache": hf_models_to_cache,
}

# check if the function is already cached
Expand Down
16 changes: 14 additions & 2 deletions src/tetra_rp/stubs/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,25 @@ def _(resource, **extra):

# Function execution
async def stubbed_resource(
func, dependencies, system_dependencies, *args, **kwargs
func,
dependencies,
system_dependencies,
accelerate_downloads,
hf_models_to_cache,
*args,
**kwargs,
) -> dict:
if args == (None,):
args = []

request = stub.prepare_request(
func, dependencies, system_dependencies, *args, **kwargs
func,
dependencies,
system_dependencies,
accelerate_downloads,
hf_models_to_cache,
*args,
**kwargs,
)
response = await stub.ExecuteFunction(request)
return stub.handle_response(response)
Expand Down
28 changes: 19 additions & 9 deletions tests/integration/test_class_execution_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def get_state(self):
}

RemoteCounter = create_remote_class(
StatefulCounter, self.mock_resource_config, [], [], {}
StatefulCounter, self.mock_resource_config, [], [], True, None, {}
)

counter = RemoteCounter(5)
Expand Down Expand Up @@ -276,7 +276,7 @@ def get_completed_count(self):
return self.tasks_completed

RemoteWorker = create_remote_class(
AsyncWorker, self.mock_resource_config, [], [], {}
AsyncWorker, self.mock_resource_config, [], [], True, None, {}
)

worker = RemoteWorker()
Expand Down Expand Up @@ -374,8 +374,10 @@ def process_with_config(self, input_data):
ConfigurableModel,
self.mock_resource_config,
["scikit-learn", "pandas"],
[],
{},
[], # system_dependencies
True, # accelerate_downloads
None, # hf_models_to_cache
{}, # extra
)

model = RemoteModel(
Expand Down Expand Up @@ -476,7 +478,7 @@ def get_service_info(self):
api_keys = ["key1", "key2", "key3"]

RemoteDataService = create_remote_class(
DataService, self.mock_resource_config, ["psycopg2"], [], {}
DataService, self.mock_resource_config, ["psycopg2"], [], True, None, {}
)

service = RemoteDataService(db_conn, cache_conf, api_keys=api_keys)
Expand Down Expand Up @@ -547,7 +549,7 @@ def safe_method(self):
return "This always works"

RemoteErrorProneClass = create_remote_class(
ErrorProneClass, self.mock_resource_config, [], [], {}
ErrorProneClass, self.mock_resource_config, [], [], True, None, {}
)

error_instance = RemoteErrorProneClass(should_fail=True)
Expand Down Expand Up @@ -583,7 +585,7 @@ def simple_method(self):
return "hello"

RemoteSimpleClass = create_remote_class(
SimpleClass, self.mock_resource_config, [], [], {}
SimpleClass, self.mock_resource_config, [], [], True, None, {}
)

instance = RemoteSimpleClass()
Expand Down Expand Up @@ -619,7 +621,7 @@ def process_file(self):

with tempfile.NamedTemporaryFile() as temp_file:
RemoteUnserializableClass = create_remote_class(
UnserializableClass, self.mock_resource_config, [], [], {}
UnserializableClass, self.mock_resource_config, [], [], True, None, {}
)

# This should not fail during initialization (lazy serialization)
Expand Down Expand Up @@ -666,6 +668,8 @@ def slow_method(self, duration):
self.mock_resource_config,
[],
[],
True,
None,
{"timeout": 5}, # 5 second timeout
)

Expand Down Expand Up @@ -700,6 +704,8 @@ def test_invalid_class_type_error(self):
self.mock_resource_config,
[],
[],
True,
None,
{},
)

Expand All @@ -708,7 +714,9 @@ def not_a_class():
pass

with pytest.raises(TypeError, match="Expected a class"):
create_remote_class(not_a_class, self.mock_resource_config, [], [], {})
create_remote_class(
not_a_class, self.mock_resource_config, [], [], True, None, {}
)

# Note: Testing class without __name__ is not practically possible
# since Python classes always have __name__ attribute
Expand All @@ -729,6 +737,8 @@ def use_dependency(self):
self.mock_resource_config,
["nonexistent-package==999.999.999"], # Invalid package
[],
True,
None,
{},
)

Expand Down
Loading
Loading