diff --git a/pyproject.toml b/pyproject.toml index c19a1bf3..5403e38b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,7 @@ dependencies = [ "cloudpickle>=3.1.1", "runpod", "python-dotenv>=1.0.0", + "pydantic>=2.0.0", ] [dependency-groups] diff --git a/src/tetra_rp/client.py b/src/tetra_rp/client.py index dd086a7c..5c2fc4ba 100644 --- a/src/tetra_rp/client.py +++ b/src/tetra_rp/client.py @@ -14,6 +14,8 @@ def remote( resource_config: ServerlessResource, dependencies: Optional[List[str]] = None, system_dependencies: Optional[List[str]] = None, + accelerate_downloads: bool = True, + hf_models_to_cache: Optional[List[str]] = None, **extra, ): """ @@ -22,10 +24,17 @@ def remote( This decorator allows a function to be executed in a remote serverless environment, with support for dynamic resource provisioning and installation of required dependencies. + Args: resource_config (ServerlessResource): Configuration object specifying the serverless resource to be provisioned or used. dependencies (List[str], optional): A list of pip package names to be installed in the remote environment before executing the function. Defaults to None. + system_dependencies (List[str], optional): A list of system packages to be installed in the remote + environment before executing the function. Defaults to None. + accelerate_downloads (bool, optional): Enable download acceleration for dependencies and models. + Defaults to True. + hf_models_to_cache (List[str], optional): List of HuggingFace model IDs to pre-cache using + download acceleration. Defaults to None. extra (dict, optional): Additional parameters for the execution of the resource. Defaults to an empty dict. Returns: @@ -37,7 +46,8 @@ def remote( @remote( resource_config=my_resource_config, dependencies=["numpy", "pandas"], - sync=True # Optional, to run synchronously + accelerate_downloads=True, + hf_models_to_cache=["gpt2", "bert-base-uncased"] ) async def my_function(data): # Function logic here @@ -49,7 +59,13 @@ def decorator(func_or_class): if inspect.isclass(func_or_class): # Handle class decoration return create_remote_class( - func_or_class, resource_config, dependencies, system_dependencies, extra + func_or_class, + resource_config, + dependencies, + system_dependencies, + accelerate_downloads, + hf_models_to_cache, + extra, ) else: # Handle function decoration (unchanged) @@ -62,7 +78,13 @@ async def wrapper(*args, **kwargs): stub = stub_resource(remote_resource, **extra) return await stub( - func_or_class, dependencies, system_dependencies, *args, **kwargs + func_or_class, + dependencies, + system_dependencies, + accelerate_downloads, + hf_models_to_cache, + *args, + **kwargs, ) return wrapper diff --git a/src/tetra_rp/execute_class.py b/src/tetra_rp/execute_class.py index 6289a02c..6ffa4849 100644 --- a/src/tetra_rp/execute_class.py +++ b/src/tetra_rp/execute_class.py @@ -202,6 +202,8 @@ def create_remote_class( resource_config: ServerlessResource, dependencies: Optional[List[str]], system_dependencies: Optional[List[str]], + accelerate_downloads: bool, + hf_models_to_cache: Optional[List[str]], extra: dict, ): """ @@ -219,6 +221,8 @@ def __init__(self, *args, **kwargs): self._resource_config = resource_config self._dependencies = dependencies or [] self._system_dependencies = system_dependencies or [] + self._accelerate_downloads = accelerate_downloads + self._hf_models_to_cache = hf_models_to_cache self._extra = extra self._constructor_args = args self._constructor_kwargs = kwargs @@ -302,6 +306,8 @@ async def method_proxy(*args, **kwargs): constructor_kwargs=constructor_kwargs, dependencies=self._dependencies, system_dependencies=self._system_dependencies, + accelerate_downloads=self._accelerate_downloads, + hf_models_to_cache=self._hf_models_to_cache, instance_id=self._instance_id, create_new_instance=not hasattr( self, "_stub" diff --git a/src/tetra_rp/protos/remote_execution.proto b/src/tetra_rp/protos/remote_execution.proto index 39341dec..c328576c 100644 --- a/src/tetra_rp/protos/remote_execution.proto +++ b/src/tetra_rp/protos/remote_execution.proto @@ -25,6 +25,10 @@ message FunctionRequest { string method_name = 12; // Name of the method to call on the class instance (default: "__call__") optional string instance_id = 13; // Unique identifier for the class instance (for persistence) bool create_new_instance = 14; // Whether to create a new instance or reuse existing one + + // Download acceleration fields + optional bool accelerate_downloads = 19; // Enable download acceleration for dependencies and models (default: true) + repeated string hf_models_to_cache = 20; // List of HuggingFace model IDs to pre-cache using acceleration } // The response message containing the execution result or error diff --git a/src/tetra_rp/protos/remote_execution.py b/src/tetra_rp/protos/remote_execution.py index 6fc80dd4..3226faf7 100644 --- a/src/tetra_rp/protos/remote_execution.py +++ b/src/tetra_rp/protos/remote_execution.py @@ -1,11 +1,22 @@ -# TODO: generate using betterproto +"""Remote execution protocol definitions using Pydantic models. + +This module defines the request/response protocol for remote function and class execution. +The models align with the protobuf schema for communication with remote workers. +""" + from abc import ABC, abstractmethod -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional from pydantic import BaseModel, Field, model_validator class FunctionRequest(BaseModel): + """Request model for remote function or class execution. + + Supports both function-based execution and class instantiation with method calls. + All serialized data (args, kwargs, etc.) are base64-encoded cloudpickle strings. + """ + # MADE OPTIONAL - can be None for class-only execution function_name: Optional[str] = Field( default=None, @@ -15,19 +26,19 @@ class FunctionRequest(BaseModel): default=None, description="Source code of the function to execute", ) - args: List = Field( + args: List[str] = Field( default_factory=list, description="List of base64-encoded cloudpickle-serialized arguments", ) - kwargs: Dict = Field( + kwargs: Dict[str, str] = Field( default_factory=dict, description="Dictionary of base64-encoded cloudpickle-serialized keyword arguments", ) - dependencies: Optional[List] = Field( + dependencies: Optional[List[str]] = Field( default=None, description="Optional list of pip packages to install before executing the function", ) - system_dependencies: Optional[List] = Field( + system_dependencies: Optional[List[str]] = Field( default=None, description="Optional list of system dependencies to install before executing the function", ) @@ -44,11 +55,11 @@ class FunctionRequest(BaseModel): default=None, description="Source code of the class to instantiate (for class execution)", ) - constructor_args: Optional[List] = Field( + constructor_args: List[str] = Field( default_factory=list, description="List of base64-encoded cloudpickle-serialized constructor arguments", ) - constructor_kwargs: Optional[Dict] = Field( + constructor_kwargs: Dict[str, str] = Field( default_factory=dict, description="Dictionary of base64-encoded cloudpickle-serialized constructor keyword arguments", ) @@ -65,6 +76,16 @@ class FunctionRequest(BaseModel): description="Whether to create a new instance or reuse existing one", ) + # Download acceleration fields + accelerate_downloads: bool = Field( + default=True, + description="Enable download acceleration for dependencies and models", + ) + hf_models_to_cache: Optional[List[str]] = Field( + default=None, + description="List of HuggingFace model IDs to pre-cache using acceleration", + ) + @model_validator(mode="after") def validate_execution_requirements(self) -> "FunctionRequest": """Validate that required fields are provided based on execution_type""" @@ -92,7 +113,12 @@ def validate_execution_requirements(self) -> "FunctionRequest": class FunctionResponse(BaseModel): - # EXISTING FIELDS (unchanged) + """Response model for remote function or class execution results. + + Contains execution results, error information, and metadata about class instances + when applicable. The result field contains base64-encoded cloudpickle data. + """ + success: bool = Field( description="Indicates if the function execution was successful", ) @@ -108,12 +134,10 @@ class FunctionResponse(BaseModel): default=None, description="Captured standard output from the function execution", ) - - # NEW FIELDS FOR CLASS SUPPORT instance_id: Optional[str] = Field( default=None, description="ID of the class instance that was used/created" ) - instance_info: Optional[Dict] = Field( + instance_info: Optional[Dict[str, Any]] = Field( default=None, description="Metadata about the class instance (creation time, call count, etc.)", ) diff --git a/src/tetra_rp/stubs/live_serverless.py b/src/tetra_rp/stubs/live_serverless.py index c8bbf672..1705530d 100644 --- a/src/tetra_rp/stubs/live_serverless.py +++ b/src/tetra_rp/stubs/live_serverless.py @@ -60,13 +60,24 @@ class LiveServerlessStub(RemoteExecutorStub): def __init__(self, server: LiveServerless): self.server = server - def prepare_request(self, func, dependencies, system_dependencies, *args, **kwargs): + def prepare_request( + self, + func, + dependencies, + system_dependencies, + accelerate_downloads, + hf_models_to_cache, + *args, + **kwargs, + ): source, src_hash = get_function_source(func) request = { "function_name": func.__name__, "dependencies": dependencies, "system_dependencies": system_dependencies, + "accelerate_downloads": accelerate_downloads, + "hf_models_to_cache": hf_models_to_cache, } # check if the function is already cached diff --git a/src/tetra_rp/stubs/registry.py b/src/tetra_rp/stubs/registry.py index 778ea3d7..261dcbff 100644 --- a/src/tetra_rp/stubs/registry.py +++ b/src/tetra_rp/stubs/registry.py @@ -26,13 +26,25 @@ def _(resource, **extra): # Function execution async def stubbed_resource( - func, dependencies, system_dependencies, *args, **kwargs + func, + dependencies, + system_dependencies, + accelerate_downloads, + hf_models_to_cache, + *args, + **kwargs, ) -> dict: if args == (None,): args = [] request = stub.prepare_request( - func, dependencies, system_dependencies, *args, **kwargs + func, + dependencies, + system_dependencies, + accelerate_downloads, + hf_models_to_cache, + *args, + **kwargs, ) response = await stub.ExecuteFunction(request) return stub.handle_response(response) diff --git a/tests/integration/test_class_execution_integration.py b/tests/integration/test_class_execution_integration.py index 545e8923..93b3052b 100644 --- a/tests/integration/test_class_execution_integration.py +++ b/tests/integration/test_class_execution_integration.py @@ -204,7 +204,7 @@ def get_state(self): } RemoteCounter = create_remote_class( - StatefulCounter, self.mock_resource_config, [], [], {} + StatefulCounter, self.mock_resource_config, [], [], True, None, {} ) counter = RemoteCounter(5) @@ -276,7 +276,7 @@ def get_completed_count(self): return self.tasks_completed RemoteWorker = create_remote_class( - AsyncWorker, self.mock_resource_config, [], [], {} + AsyncWorker, self.mock_resource_config, [], [], True, None, {} ) worker = RemoteWorker() @@ -374,8 +374,10 @@ def process_with_config(self, input_data): ConfigurableModel, self.mock_resource_config, ["scikit-learn", "pandas"], - [], - {}, + [], # system_dependencies + True, # accelerate_downloads + None, # hf_models_to_cache + {}, # extra ) model = RemoteModel( @@ -476,7 +478,7 @@ def get_service_info(self): api_keys = ["key1", "key2", "key3"] RemoteDataService = create_remote_class( - DataService, self.mock_resource_config, ["psycopg2"], [], {} + DataService, self.mock_resource_config, ["psycopg2"], [], True, None, {} ) service = RemoteDataService(db_conn, cache_conf, api_keys=api_keys) @@ -547,7 +549,7 @@ def safe_method(self): return "This always works" RemoteErrorProneClass = create_remote_class( - ErrorProneClass, self.mock_resource_config, [], [], {} + ErrorProneClass, self.mock_resource_config, [], [], True, None, {} ) error_instance = RemoteErrorProneClass(should_fail=True) @@ -583,7 +585,7 @@ def simple_method(self): return "hello" RemoteSimpleClass = create_remote_class( - SimpleClass, self.mock_resource_config, [], [], {} + SimpleClass, self.mock_resource_config, [], [], True, None, {} ) instance = RemoteSimpleClass() @@ -619,7 +621,7 @@ def process_file(self): with tempfile.NamedTemporaryFile() as temp_file: RemoteUnserializableClass = create_remote_class( - UnserializableClass, self.mock_resource_config, [], [], {} + UnserializableClass, self.mock_resource_config, [], [], True, None, {} ) # This should not fail during initialization (lazy serialization) @@ -666,6 +668,8 @@ def slow_method(self, duration): self.mock_resource_config, [], [], + True, + None, {"timeout": 5}, # 5 second timeout ) @@ -700,6 +704,8 @@ def test_invalid_class_type_error(self): self.mock_resource_config, [], [], + True, + None, {}, ) @@ -708,7 +714,9 @@ def not_a_class(): pass with pytest.raises(TypeError, match="Expected a class"): - create_remote_class(not_a_class, self.mock_resource_config, [], [], {}) + create_remote_class( + not_a_class, self.mock_resource_config, [], [], True, None, {} + ) # Note: Testing class without __name__ is not practically possible # since Python classes always have __name__ attribute @@ -729,6 +737,8 @@ def use_dependency(self): self.mock_resource_config, ["nonexistent-package==999.999.999"], # Invalid package [], + True, + None, {}, ) diff --git a/tests/unit/test_class_caching.py b/tests/unit/test_class_caching.py index 8152e889..920bb908 100644 --- a/tests/unit/test_class_caching.py +++ b/tests/unit/test_class_caching.py @@ -143,7 +143,7 @@ def __init__(self, value): self.value = value RemoteCacheTestClass = create_remote_class( - CacheTestClass, self.mock_resource_config, [], [], {} + CacheTestClass, self.mock_resource_config, [], [], True, None, {} ) # First instance - should be cache miss @@ -177,7 +177,7 @@ def __init__(self, x, y=None): self.y = y RemoteMultiArgClass = create_remote_class( - MultiArgClass, self.mock_resource_config, [], [], {} + MultiArgClass, self.mock_resource_config, [], [], True, None, {} ) # Different args should create different cache entries @@ -198,7 +198,7 @@ def __init__(self, file_handle, name="default"): self.name = name RemoteFileHandlerClass = create_remote_class( - FileHandlerClass, self.mock_resource_config, [], [], {} + FileHandlerClass, self.mock_resource_config, [], [], True, None, {} ) with tempfile.NamedTemporaryFile() as temp_file: @@ -224,7 +224,7 @@ def __init__(self, value): self.value = value RemoteOptimizationTestClass = create_remote_class( - OptimizationTestClass, self.mock_resource_config, [], [], {} + OptimizationTestClass, self.mock_resource_config, [], [], True, None, {} ) with patch("tetra_rp.execute_class.extract_class_code_simple") as mock_extract: @@ -250,7 +250,7 @@ def get_value(self): return self.value RemoteConsistencyTestClass = create_remote_class( - ConsistencyTestClass, self.mock_resource_config, [], [], {} + ConsistencyTestClass, self.mock_resource_config, [], [], True, None, {} ) instance1 = RemoteConsistencyTestClass(1) @@ -273,7 +273,7 @@ def __init__(self, file_handle): self.file_handle = file_handle RemoteUUIDFallbackClass = create_remote_class( - UUIDFallbackClass, self.mock_resource_config, [], [], {} + UUIDFallbackClass, self.mock_resource_config, [], [], True, None, {} ) with ( @@ -299,7 +299,7 @@ def __init__(self, value): self.value = value RemoteMemoryTestClass = create_remote_class( - MemoryTestClass, self.mock_resource_config, [], [], {} + MemoryTestClass, self.mock_resource_config, [], [], True, None, {} ) # Create many instances with same args - should only create one cache entry @@ -323,10 +323,10 @@ def __init__(self, value): self.value = value RemoteClassTypeA = create_remote_class( - ClassTypeA, self.mock_resource_config, [], [], {} + ClassTypeA, self.mock_resource_config, [], [], True, None, {} ) RemoteClassTypeB = create_remote_class( - ClassTypeB, self.mock_resource_config, [], [], {} + ClassTypeB, self.mock_resource_config, [], [], True, None, {} ) instanceA = RemoteClassTypeA(42) @@ -358,7 +358,7 @@ def __init__(self, value, config=None): ) RemoteStructureTestClass = create_remote_class( - StructureTestClass, resource_config, [], [], {} + StructureTestClass, resource_config, [], [], True, None, {} ) instance = RemoteStructureTestClass(42, config={"key": "value"}) @@ -401,7 +401,7 @@ def __init__(self, data): ) RemoteSerializationTestClass = create_remote_class( - SerializationTestClass, resource_config, [], [], {} + SerializationTestClass, resource_config, [], [], True, None, {} ) test_data = {"test": [1, 2, 3]} diff --git a/tests/unit/test_execute_class.py b/tests/unit/test_execute_class.py index 9e28711b..5c0bbe86 100644 --- a/tests/unit/test_execute_class.py +++ b/tests/unit/test_execute_class.py @@ -243,6 +243,8 @@ def get_value(self): self.mock_resource_config, self.dependencies, self.system_dependencies, + True, + None, self.extra, ) @@ -263,6 +265,8 @@ def __init__(self, value, name="default"): self.mock_resource_config, self.dependencies, self.system_dependencies, + True, + None, self.extra, ) @@ -290,6 +294,8 @@ class TestClass: self.mock_resource_config, None, # dependencies None, # system_dependencies + True, # accelerate_downloads + None, # hf_models_to_cache self.extra, ) @@ -312,6 +318,8 @@ class TestClass: self.mock_resource_config, self.dependencies, self.system_dependencies, + True, + None, self.extra, ) @@ -347,6 +355,8 @@ class TestClass: self.mock_resource_config, self.dependencies, self.system_dependencies, + True, + None, self.extra, ) @@ -387,6 +397,8 @@ class TestClass: self.mock_resource_config, self.dependencies, self.system_dependencies, + True, + None, self.extra, ) @@ -417,6 +429,8 @@ def add(self, x, y=10): self.mock_resource_config, self.dependencies, self.system_dependencies, + True, + None, self.extra, ) @@ -481,6 +495,8 @@ def method2(self): self.mock_resource_config, self.dependencies, self.system_dependencies, + True, + None, self.extra, ) @@ -520,7 +536,7 @@ def simple_method(self): return "simple" RemoteWrapper = create_remote_class( - TestClass, self.mock_resource_config, [], [], {} + TestClass, self.mock_resource_config, [], [], True, None, {} ) instance = RemoteWrapper() @@ -562,6 +578,8 @@ def test_method(self): self.mock_resource_config, self.dependencies, self.system_dependencies, + True, + None, self.extra, ) @@ -586,6 +604,8 @@ class TestClass: self.mock_resource_config, self.dependencies, self.system_dependencies, + True, + None, self.extra, ) @@ -632,7 +652,7 @@ def get_value(self): ) RemoteCalculator = create_remote_class( - CalculatorClass, resource_config, ["numpy"], [], {"timeout": 60} + CalculatorClass, resource_config, ["numpy"], [], True, None, {"timeout": 60} ) calculator = RemoteCalculator(10) @@ -680,9 +700,11 @@ def complex_method( RemoteWrapper = create_remote_class( ComplexClass, ServerlessResource(name="test", image="test:latest", cpu=1, memory=256), - [], - [], - {}, + [], # dependencies + [], # system_dependencies + True, # accelerate_downloads + None, # hf_models_to_cache + {}, # extra ) instance = RemoteWrapper("test", extra_arg=True) diff --git a/uv.lock b/uv.lock index b7c8a231..4929c872 100644 --- a/uv.lock +++ b/uv.lock @@ -2345,10 +2345,11 @@ wheels = [ [[package]] name = "tetra-rp" -version = "0.9.0" +version = "0.10.0" source = { editable = "." } dependencies = [ { name = "cloudpickle" }, + { name = "pydantic" }, { name = "python-dotenv" }, { name = "runpod" }, ] @@ -2369,6 +2370,7 @@ test = [ [package.metadata] requires-dist = [ { name = "cloudpickle", specifier = ">=3.1.1" }, + { name = "pydantic", specifier = ">=2.0.0" }, { name = "python-dotenv", specifier = ">=1.0.0" }, { name = "runpod" }, ]