diff --git a/sagemaker-serve/src/sagemaker/serve/model_builder_utils.py b/sagemaker-serve/src/sagemaker/serve/model_builder_utils.py index 8c1fd6db1b..f9efe42a18 100644 --- a/sagemaker-serve/src/sagemaker/serve/model_builder_utils.py +++ b/sagemaker-serve/src/sagemaker/serve/model_builder_utils.py @@ -70,7 +70,7 @@ def build(self): _cast_to_compatible_version, _detect_framework_and_version, auto_detect_container, - _get_model_base + _get_model_base, ) from sagemaker.serve.mode.function_pointers import Mode from sagemaker.serve.utils import task @@ -131,7 +131,6 @@ def build(self): from sagemaker.serve.detector.pickler import save_pkl from sagemaker.serve.builder.requirements_manager import RequirementsManager from sagemaker.serve.validations.check_integrity import ( - generate_secret_key, compute_hash, ) from sagemaker.core.remote_function.core.serialization import _MetaData @@ -224,19 +223,20 @@ def serialize(self, data): return super().serialize(payload) + class _ModelBuilderUtils: """Utility mixin class providing common functionality for ModelBuilder. - + This class provides utility methods for: - Session management and initialization - - Instance type detection and optimization + - Instance type detection and optimization - Container image auto-detection - HuggingFace and JumpStart model handling - Resource requirement calculation - Framework serialization support - MLflow model integration - General model deployment utilities - + This class is designed to be used as a mixin with ModelBuilder classes. It expects certain attributes to be available on the instance: - sagemaker_session: SageMaker session object @@ -244,14 +244,14 @@ class _ModelBuilderUtils: - instance_type: EC2 instance type - region: AWS region - env_vars: Environment variables dict - + Example: class MyModelBuilder(ModelBuilderUtils): def __init__(self): self.model = "huggingface-model-id" self.instance_type = "ml.g5.xlarge" self.sagemaker_session = None - + def build(self): self._init_sagemaker_session_if_does_not_exist() self._auto_detect_image_uri() @@ -262,7 +262,9 @@ def build(self): # Session Management # ======================================== - def _init_sagemaker_session_if_does_not_exist(self, instance_type: Optional[str] = None) -> None: + def _init_sagemaker_session_if_does_not_exist( + self, instance_type: Optional[str] = None + ) -> None: """Initialize SageMaker session if it doesn't exist. 
Sets self.sagemaker_session to LocalSession for local instances, @@ -275,24 +277,25 @@ def _init_sagemaker_session_if_does_not_exist(self, instance_type: Optional[str] if self.sagemaker_session: return - effective_instance_type = instance_type or getattr(self, 'instance_type', None) - + effective_instance_type = instance_type or getattr(self, "instance_type", None) + if effective_instance_type in ("local", "local_gpu"): self.sagemaker_session = LocalSession( - sagemaker_config=getattr(self, '_sagemaker_config', None) + sagemaker_config=getattr(self, "_sagemaker_config", None) ) else: # Create session with correct region - if hasattr(self, 'region') and self.region: + if hasattr(self, "region") and self.region: import boto3 + boto_session = boto3.Session(region_name=self.region) self.sagemaker_session = Session( boto_session=boto_session, - sagemaker_config=getattr(self, '_sagemaker_config', None) + sagemaker_config=getattr(self, "_sagemaker_config", None), ) else: self.sagemaker_session = Session( - sagemaker_config=getattr(self, '_sagemaker_config', None) + sagemaker_config=getattr(self, "_sagemaker_config", None) ) # ======================================== @@ -301,98 +304,100 @@ def _init_sagemaker_session_if_does_not_exist(self, instance_type: Optional[str] def _get_jumpstart_recommended_instance_type(self) -> Optional[str]: """Get recommended instance type from JumpStart metadata. - + Returns: Recommended instance type string, or None if not available. """ try: deploy_kwargs = get_deploy_kwargs( model_id=self.model, - model_version=getattr(self, 'model_version', None) or "*", + model_version=getattr(self, "model_version", None) or "*", region=self.region, - tolerate_vulnerable_model=getattr(self, 'tolerate_vulnerable_model', None), - tolerate_deprecated_model=getattr(self, 'tolerate_deprecated_model', None) + tolerate_vulnerable_model=getattr(self, "tolerate_vulnerable_model", None), + tolerate_deprecated_model=getattr(self, "tolerate_deprecated_model", None), ) - + # JumpStart provides recommended instance type - if hasattr(deploy_kwargs, 'instance_type') and deploy_kwargs.instance_type: + if hasattr(deploy_kwargs, "instance_type") and deploy_kwargs.instance_type: return deploy_kwargs.instance_type - + except Exception: pass - + return None def _get_default_instance_type(self) -> str: """Get optimal default instance type based on model characteristics. - + Analyzes the model to determine appropriate instance type: - JumpStart models: Use recommended instance type from metadata - HuggingFace models: Analyze model size and tags for GPU requirements - Fallback: ml.m5.large for CPU workloads - + Returns: Instance type string (e.g., 'ml.g5.xlarge', 'ml.m5.large'). 
""" logger.debug("Auto-detecting optimal instance type for model...") - + if isinstance(self.model, str) and self._is_jumpstart_model_id(): recommended_type = self._get_jumpstart_recommended_instance_type() if recommended_type: logger.debug(f"Using JumpStart recommended instance type: {recommended_type}") return recommended_type - + # For HuggingFace models, use metadata to detect requirements elif isinstance(self.model, str): try: - env_vars = getattr(self, 'env_vars', {}) or {} + env_vars = getattr(self, "env_vars", {}) or {} hf_model_md = self.get_huggingface_model_metadata( - self.model, - env_vars.get("HUGGING_FACE_HUB_TOKEN") + self.model, env_vars.get("HUGGING_FACE_HUB_TOKEN") ) - + # Check model size from metadata model_size = hf_model_md.get("safetensors", {}).get("total", 0) model_tags = hf_model_md.get("tags", []) - + # Large models or specific tags indicate GPU need - if (model_size > 2_000_000_000 or # > 2GB - any(tag in model_tags for tag in ["7b", "13b", "70b"]) or - "7b" in self.model.lower() or "13b" in self.model.lower()): + if ( + model_size > 2_000_000_000 # > 2GB + or any(tag in model_tags for tag in ["7b", "13b", "70b"]) + or "7b" in self.model.lower() + or "13b" in self.model.lower() + ): logger.debug("Detected large model, using GPU instance type: ml.g5.xlarge") return "ml.g5.xlarge" - + except Exception as e: logger.debug(f"Could not get HF metadata for smart detection: {e}") - + # Default fallback logger.debug("Using default CPU instance type: ml.m5.large") return "ml.m5.large" - + # ======================================== # Image Detection and Container Utils # ======================================== def _auto_detect_container_default(self) -> str: """Auto-detect container image for framework-based models. - + Detects the appropriate Deep Learning Container (DLC) based on: - Model framework (PyTorch, TensorFlow) - Framework version from HuggingFace metadata - Python version compatibility - Instance type requirements - + Returns: Container image URI string. - + Raises: ValueError: If instance type not specified or no compatible image found. """ from sagemaker.core import image_uris - + logger.debug("Auto-detecting image since image_uri was not provided in ModelBuilder()") - if not getattr(self, 'instance_type', None): + if not getattr(self, "instance_type", None): raise ValueError( "Instance type is not specified. " "Unable to detect if the container needs to be GPU or CPU." @@ -403,13 +408,12 @@ def _auto_detect_container_default(self) -> str: ) py_tuple = platform.python_version_tuple() - env_vars = getattr(self, 'env_vars', {}) or {} - + env_vars = getattr(self, "env_vars", {}) or {} + torch_v, tf_v, base_hf_v, _ = self._get_hf_framework_versions( - self.model, - env_vars.get("HUGGING_FACE_HUB_TOKEN") + self.model, env_vars.get("HUGGING_FACE_HUB_TOKEN") ) - + if torch_v: fw, fw_version = "pytorch", torch_v elif tf_v: @@ -446,21 +450,20 @@ def _auto_detect_container_default(self) -> str: f"framework version {fw_version} and python version py{py_tuple[0]}{py_tuple[1]}. " f"Please manually provide image_uri to ModelBuilder()" ) - def _get_smd_image_uri(self, processing_unit: Optional[str] = None) -> str: """Get SageMaker Distribution (SMD) inference image URI. - + Retrieves the appropriate SMD container image for custom orchestrator deployment. Requires Python >= 3.12 for SMD inference. - + Args: processing_unit: Target processing unit ('cpu' or 'gpu'). If None, defaults to 'cpu'. - + Returns: SMD inference image URI string. 
- + Raises: ValueError: If Python version < 3.12 or invalid processing unit. """ @@ -468,8 +471,9 @@ def _get_smd_image_uri(self, processing_unit: Optional[str] = None) -> str: from sagemaker.core import image_uris if not self.sagemaker_session: - if hasattr(self, 'region') and self.region: + if hasattr(self, "region") and self.region: import boto3 + boto_session = boto3.Session(region_name=self.region) self.sagemaker_session = Session(boto_session=boto_session) else: @@ -484,14 +488,16 @@ def _get_smd_image_uri(self, processing_unit: Optional[str] = None) -> str: INSTANCE_TYPES = {"cpu": "ml.c5.xlarge", "gpu": "ml.g5.4xlarge"} effective_processing_unit = processing_unit or "cpu" - + if effective_processing_unit not in INSTANCE_TYPES: raise ValueError( f"Invalid processing unit '{effective_processing_unit}'. " f"Must be one of: {list(INSTANCE_TYPES.keys())}" ) - logger.debug("Finding SMD inference image URI for a %s instance.", effective_processing_unit) + logger.debug( + "Finding SMD inference image URI for a %s instance.", effective_processing_unit + ) smd_uri = image_uris.retrieve( framework="sagemaker-distribution", @@ -502,163 +508,172 @@ def _get_smd_image_uri(self, processing_unit: Optional[str] = None) -> str: logger.debug("Found compatible image: %s", smd_uri) return smd_uri - def _is_huggingface_model(self) -> bool: """Check if model is a HuggingFace model ID. - + Determines if the model string represents a HuggingFace model by: - Checking for organization/model-name format - Checking explicit model_type designation - Fallback: assume HuggingFace if not JumpStart - + Returns: True if model appears to be a HuggingFace model ID. """ if not isinstance(self.model, str): return False - + # Simple pattern matching for HuggingFace model IDs # Format: "organization/model-name" or just "model-name" - model_type = getattr(self, 'model_type', None) + model_type = getattr(self, "model_type", None) if "/" in self.model or model_type == "huggingface": return True - + # Additional check: if it's not a JumpStart model, assume HuggingFace return not self._is_jumpstart_model_id() - - def _get_supported_version(self, hf_config: Dict[str, Any], hugging_face_version: str, base_fw: str) -> str: + def _get_supported_version( + self, hf_config: Dict[str, Any], hugging_face_version: str, base_fw: str + ) -> str: """Extract supported framework version from HuggingFace config. - + Uses the HuggingFace JSON config to pick the best supported version for the given framework. - + Args: hf_config: HuggingFace configuration dictionary hugging_face_version: HuggingFace transformers version base_fw: Base framework name (e.g., 'pytorch', 'tensorflow') - + Returns: Best supported framework version string. 
""" version_config = hf_config.get("versions", {}).get(hugging_face_version, {}) versions_to_return = [] - + for key in version_config.keys(): if key.startswith(base_fw): - base_fw_version = key[len(base_fw):] + base_fw_version = key[len(base_fw) :] if len(hugging_face_version.split(".")) == 2: base_fw_version = ".".join(base_fw_version.split(".")[:-1]) versions_to_return.append(base_fw_version) - + if not versions_to_return: raise ValueError(f"No supported versions found for framework {base_fw}") - + return sorted(versions_to_return, reverse=True)[0] - def _get_hf_framework_versions(self, model_id: str, hf_token: Optional[str] = None) -> Tuple[Optional[str], Optional[str], str, str]: + def _get_hf_framework_versions( + self, model_id: str, hf_token: Optional[str] = None + ) -> Tuple[Optional[str], Optional[str], str, str]: """Get HuggingFace framework versions for image_uris.retrieve(). - + Analyzes HuggingFace model metadata to determine the appropriate framework versions for container image selection. - + Args: model_id: HuggingFace model identifier hf_token: Optional HuggingFace API token for private models - + Returns: Tuple of (pytorch_version, tensorflow_version, transformers_version, py_version). One of pytorch_version or tensorflow_version will be None. - + Raises: ValueError: If no supported framework versions found. """ from sagemaker.core import image_uris - + # Get model metadata for framework detection hf_model_md = self.get_huggingface_model_metadata(model_id, hf_token) - + # Get HuggingFace framework configuration hf_config = image_uris.config_for_framework("huggingface").get("inference") config = hf_config["versions"] base_hf_version = sorted(config.keys(), key=lambda v: Version(v), reverse=True)[0] - + model_tags = hf_model_md.get("tags", []) - + # Detect framework from model tags if "pytorch" in model_tags: pytorch_version = self._get_supported_version(hf_config, base_hf_version, "pytorch") - py_version = config[base_hf_version][f"pytorch{pytorch_version}"].get("py_versions", [])[-1] + py_version = config[base_hf_version][f"pytorch{pytorch_version}"].get( + "py_versions", [] + )[-1] return pytorch_version, None, base_hf_version, py_version - + elif "keras" in model_tags or "tensorflow" in model_tags: - tensorflow_version = self._get_supported_version(hf_config, base_hf_version, "tensorflow") - py_version = config[base_hf_version][f"tensorflow{tensorflow_version}"].get("py_versions", [])[-1] + tensorflow_version = self._get_supported_version( + hf_config, base_hf_version, "tensorflow" + ) + py_version = config[base_hf_version][f"tensorflow{tensorflow_version}"].get( + "py_versions", [] + )[-1] return None, tensorflow_version, base_hf_version, py_version - + else: # Default to PyTorch if no framework detected (matches V2 behavior) pytorch_version = self._get_supported_version(hf_config, base_hf_version, "pytorch") - py_version = config[base_hf_version][f"pytorch{pytorch_version}"].get("py_versions", [])[-1] + py_version = config[base_hf_version][f"pytorch{pytorch_version}"].get( + "py_versions", [] + )[-1] return pytorch_version, None, base_hf_version, py_version - def _detect_jumpstart_image(self) -> None: """Detect and set image URI for JumpStart models. - + Uses JumpStart metadata to determine the appropriate container image and framework information for the model. - + Raises: ValueError: If image URI cannot be determined or JumpStart lookup fails. 
""" try: init_kwargs = get_init_kwargs( model_id=self.model, - model_version=getattr(self, 'model_version', None) or "*", + model_version=getattr(self, "model_version", None) or "*", region=self.region, - instance_type=getattr(self, 'instance_type', None), - tolerate_vulnerable_model=getattr(self, 'tolerate_vulnerable_model', None), - tolerate_deprecated_model=getattr(self, 'tolerate_deprecated_model', None) + instance_type=getattr(self, "instance_type", None), + tolerate_vulnerable_model=getattr(self, "tolerate_vulnerable_model", None), + tolerate_deprecated_model=getattr(self, "tolerate_deprecated_model", None), ) - + self.image_uri = init_kwargs.get("image_uri") if not self.image_uri: raise ValueError(f"Could not determine image URI for JumpStart model: {self.model}") - + logger.debug("Auto-detected JumpStart image: %s", self.image_uri) self.framework, self.framework_version = self._extract_framework_from_image_uri() - + except Exception as e: - raise ValueError(f"Failed to auto-detect image for JumpStart model {self.model}: {e}") from e + raise ValueError( + f"Failed to auto-detect image for JumpStart model {self.model}: {e}" + ) from e - def _detect_huggingface_image(self) -> None: """Detect and set image URI for HuggingFace models based on model server. - + Automatically selects the appropriate container image based on: - Explicit model_server setting - Model task type from HuggingFace metadata - Framework requirements and versions - + Raises: ValueError: If image detection fails or unsupported model server. """ from sagemaker.core import image_uris - + try: - env_vars = getattr(self, 'env_vars', {}) or {} - + env_vars = getattr(self, "env_vars", {}) or {} + # Determine which model server we're using - model_server = getattr(self, 'model_server', None) + model_server = getattr(self, "model_server", None) if not model_server: # Auto-select model server based on HF model task hf_model_md = self.get_huggingface_model_metadata( - self.model, - env_vars.get("HUGGING_FACE_HUB_TOKEN") + self.model, env_vars.get("HUGGING_FACE_HUB_TOKEN") ) model_task = hf_model_md.get("pipeline_tag") - + if model_task == "text-generation": effective_model_server = ModelServer.TGI elif model_task in ["sentence-similarity", "feature-extraction"]: @@ -667,7 +682,7 @@ def _detect_huggingface_image(self) -> None: effective_model_server = ModelServer.MMS # Transformers else: effective_model_server = model_server - + # Choose image based on effective model server if effective_model_server == ModelServer.TGI: # TGI: Use image_uris.retrieve with "huggingface-llm" framework @@ -684,11 +699,11 @@ def _detect_huggingface_image(self) -> None: self.image_uri = image_uris.retrieve( framework="huggingface-tei", image_scope="inference", - instance_type=getattr(self, 'instance_type', None), + instance_type=getattr(self, "instance_type", None), region=self.region, ) self.framework = Framework.HUGGINGFACE - + elif effective_model_server == ModelServer.DJL_SERVING: # DJL: Use image_uris.retrieve with "djl-lmi" framework (matches DJLModel default) self.image_uri = image_uris.retrieve( @@ -696,109 +711,108 @@ def _detect_huggingface_image(self) -> None: region=self.region, version="latest", image_scope="inference", - instance_type=getattr(self, 'instance_type', None) + instance_type=getattr(self, "instance_type", None), ) self.framework = Framework.DJL - + elif effective_model_server == ModelServer.MMS: # Transformers # Transformers: Use HuggingFace framework with detected versions - pytorch_version, tensorflow_version, 
transformers_version, py_version = \ + pytorch_version, tensorflow_version, transformers_version, py_version = ( self._get_hf_framework_versions( - self.model, - env_vars.get("HUGGING_FACE_HUB_TOKEN") + self.model, env_vars.get("HUGGING_FACE_HUB_TOKEN") ) - + ) + base_framework_version = ( - f"pytorch{pytorch_version}" if pytorch_version + f"pytorch{pytorch_version}" + if pytorch_version else f"tensorflow{tensorflow_version}" ) - + self.image_uri = image_uris.retrieve( framework="huggingface", region=self.region, version=transformers_version, py_version=py_version, - instance_type=getattr(self, 'instance_type', None), + instance_type=getattr(self, "instance_type", None), image_scope="inference", base_framework_version=base_framework_version, ) self.framework = Framework.HUGGINGFACE - + elif effective_model_server == ModelServer.TORCHSERVE: # TorchServe: Use HuggingFace framework with detected versions - pytorch_version, tensorflow_version, transformers_version, py_version = \ + pytorch_version, tensorflow_version, transformers_version, py_version = ( self._get_hf_framework_versions( - self.model, - env_vars.get("HUGGING_FACE_HUB_TOKEN") + self.model, env_vars.get("HUGGING_FACE_HUB_TOKEN") ) - + ) + base_framework_version = ( - f"pytorch{pytorch_version}" if pytorch_version + f"pytorch{pytorch_version}" + if pytorch_version else f"tensorflow{tensorflow_version}" ) - + self.image_uri = image_uris.retrieve( framework="huggingface", region=self.region, version=transformers_version, py_version=py_version, - instance_type=getattr(self, 'instance_type', None), + instance_type=getattr(self, "instance_type", None), image_scope="inference", base_framework_version=base_framework_version, ) self.framework = Framework.HUGGINGFACE - + elif effective_model_server == ModelServer.TRITON: # Triton: Uses custom image construction (not image_uris.retrieve) raise ValueError( "Triton image detection for HuggingFace models requires custom implementation" ) - + elif effective_model_server == ModelServer.TENSORFLOW_SERVING: # TensorFlow Serving: V2 required explicit image_uri (no auto-detection) raise ValueError("TensorFlow Serving requires explicit image_uri specification") - + elif effective_model_server == ModelServer.SMD: # SMD: Uses _get_smd_image_uri helper cpu_or_gpu = self._get_processing_unit() self.image_uri = self._get_smd_image_uri(processing_unit=cpu_or_gpu) self.framework = Framework.SMD - + else: raise ValueError( f"Unsupported model server for HuggingFace models: {effective_model_server}" ) - + logger.debug("Auto-detected HuggingFace image: %s", self.image_uri) - + except Exception as e: raise ValueError( f"Failed to auto-detect image for HuggingFace model {self.model}: {e}" ) from e - def _detect_model_object_image(self) -> None: """Detect image for legacy object-based models. - + Handles model objects (not string IDs) by using the auto_detect_container function to determine appropriate container image. - + Raises: ValueError: If neither model nor inference_spec available for detection. 
""" - model = getattr(self, 'model', None) - inference_spec = getattr(self, 'inference_spec', None) - model_path = getattr(self, 'model_path', None) - + model = getattr(self, "model", None) + inference_spec = getattr(self, "inference_spec", None) + model_path = getattr(self, "model_path", None) + if model: logger.debug( "Auto-detecting container URL for the provided model on instance %s", - getattr(self, 'instance_type', None), + getattr(self, "instance_type", None), ) self.image_uri, fw, self.framework_version = auto_detect_container( - model, - self.region, - getattr(self, 'instance_type', None) + model, self.region, getattr(self, "instance_type", None) ) self.framework = self._normalize_framework_to_enum(fw) @@ -810,27 +824,26 @@ def _detect_model_object_image(self) -> None: self.image_uri, fw, self.framework_version = auto_detect_container( inference_spec.load(model_path), self.region, - getattr(self, 'instance_type', None), + getattr(self, "instance_type", None), ) self.framework = self._normalize_framework_to_enum(fw) else: raise ValueError("Cannot detect required model or inference spec") - def _auto_detect_image_uri(self) -> None: """Auto-detect container image URI based on model type. - + Determines the appropriate container image by: 1. Using provided image_uri if available 2. For string models: JumpStart vs HuggingFace detection 3. For object models: Legacy auto-detection - + Sets self.image_uri, self.framework, and self.framework_version. - + Raises: ValueError: If image cannot be auto-detected for the model type. """ - image_uri = getattr(self, 'image_uri', None) + image_uri = getattr(self, "image_uri", None) if image_uri: self.framework, self.framework_version = self._extract_framework_from_image_uri() logger.debug("Skipping auto-detection as image_uri is provided: %s", image_uri) @@ -840,13 +853,13 @@ def _auto_detect_image_uri(self) -> None: self._detect_inference_image_from_training() return - model = getattr(self, 'model', None) - inference_spec = getattr(self, 'inference_spec', None) + model = getattr(self, "model", None) + inference_spec = getattr(self, "inference_spec", None) if isinstance(model, str): # V3: String-based model detection - model_type = getattr(self, 'model_type', None) - + model_type = getattr(self, "model_type", None) + # First priority: Use model_type if it indicates JumpStart if model_type in ["open_weights", "proprietary"]: self._detect_jumpstart_image() @@ -858,38 +871,40 @@ def _auto_detect_image_uri(self) -> None: self._detect_huggingface_image() else: raise ValueError(f"Cannot auto-detect image for model: {model}") - elif inference_spec and hasattr(inference_spec, 'get_model'): + elif inference_spec and hasattr(inference_spec, "get_model"): try: spec_model = inference_spec.get_model() if spec_model is None: logger.warning( - "InferenceSpec.get_model() returned None. If you are using a JumpStar or HuggingFace model, you may need to implement get_model() in your InferenceSpec class") - + "InferenceSpec.get_model() returned None. 
If you are using a JumpStart or HuggingFace model, you may need to implement get_model() in your InferenceSpec class"
+                    )
+
                 if isinstance(spec_model, str):
                     # Temporarily set model for detection, then restore
                     original_model = self.model
                     self.model = spec_model
-
+
                     # Use existing detection logic
                     if self._is_jumpstart_model_id():
                         self._detect_jumpstart_image()
                     elif self._is_huggingface_model():
                         self._detect_huggingface_image()
                     else:
-                        raise ValueError(f"Cannot auto-detect image for inference_spec model: {spec_model}")
-
+                        raise ValueError(
+                            f"Cannot auto-detect image for inference_spec model: {spec_model}"
+                        )
+
                     # Restore original model
                     self.model = original_model
                     return
             except Exception as e:
                 pass
-
+
             # Fall back to existing object detection
             self._detect_model_object_image()
         else:
             # V2: Object-based model detection
             self._detect_model_object_image()
-

     # ========================================
     # HuggingFace Jumpstart Utils
@@ -897,32 +912,32 @@ def _auto_detect_image_uri(self) -> None:

     def _use_jumpstart_equivalent(self) -> bool:
         """Check if HuggingFace model has JumpStart equivalent and use it.
-
+
         Replaces the HuggingFace model with its JumpStart equivalent if available.
         Skips replacement if image_uri or env_vars are explicitly provided.
-
+
         Returns:
             True if JumpStart equivalent was found and used, False otherwise.
         """
         # Do not use the equivalent JS model if image_uri or env_vars is provided
-        image_uri = getattr(self, 'image_uri', None)
-        env_vars = getattr(self, 'env_vars', None)
+        image_uri = getattr(self, "image_uri", None)
+        env_vars = getattr(self, "env_vars", None)
         if image_uri or env_vars:
             return False
-
+
         if not hasattr(self, "_has_jumpstart_equivalent"):
             self._jumpstart_mapping = self._retrieve_hugging_face_model_mapping()
             self._has_jumpstart_equivalent = self.model in self._jumpstart_mapping
-
+
         if self._has_jumpstart_equivalent:
             # Use schema builder from HF model metadata
-            schema_builder = getattr(self, 'schema_builder', None)
+            schema_builder = getattr(self, "schema_builder", None)
             if not schema_builder:
                 model_task = None
-                model_metadata = getattr(self, 'model_metadata', None)
+                model_metadata = getattr(self, "model_metadata", None)
                 if model_metadata:
                     model_task = model_metadata.get("HF_TASK")
-
+
                 hf_model_md = self.get_huggingface_model_metadata(self.model)
                 if not model_task:
                     model_task = hf_model_md.get("pipeline_tag")
@@ -933,19 +948,19 @@ def _use_jumpstart_equivalent(self) -> bool:
             jumpstart_model_id = self._jumpstart_mapping[huggingface_model_id]["jumpstart-model-id"]
             self.model = jumpstart_model_id
             merged_date = self._jumpstart_mapping[huggingface_model_id].get("merged-at")
-
+
             # Call _build_for_jumpstart if method exists
-            if hasattr(self, '_build_for_jumpstart'):
+            if hasattr(self, "_build_for_jumpstart"):
                 self._build_for_jumpstart()
-
+
             compare_model_diff_message = (
                 "If you want to identify the differences between the two, "
                 "please use model_uris.retrieve() to retrieve the model "
                 "artifact S3 URI and compare them."
             )
-
-            is_gated = (hasattr(self, '_is_gated_model') and self._is_gated_model())
-
+
+            is_gated = hasattr(self, "_is_gated_model") and self._is_gated_model()
+
             logger.warning(
                 "Please note that for this model we are using the JumpStart's "
                 f'local copy "{jumpstart_model_id}" '
@@ -958,7 +973,6 @@ def _use_jumpstart_equivalent(self) -> bool:
             return True
         return False
-

     def _hf_schema_builder_init(self, model_task: str) -> None:
         """Initialize schema builder for HuggingFace model task.
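For reference, the local-schema path inside _hf_schema_builder_init (next hunk) amounts to roughly the following sketch. The task name is illustrative, and the SchemaBuilder import path is an assumption, not something this patch touches:

    from sagemaker.serve.builder.schema_builder import SchemaBuilder  # assumed import path
    from sagemaker.serve.utils import task

    # Try the locally bundled samples first; _hf_schema_builder_init falls back
    # to the remote HF schema retriever when this raises ValueError.
    sample_inputs, sample_outputs = task.retrieve_local_schemas("text-classification")
    schema_builder = SchemaBuilder(sample_inputs, sample_outputs)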
@@ -976,8 +990,7 @@ def _hf_schema_builder_init(self, model_task: str) -> None: sample_inputs, sample_outputs = task.retrieve_local_schemas(model_task) except ValueError: # Samples could not be loaded locally, try to fetch remote HF schema - from sagemaker_schema_inference_artifacts.huggingface import \ - remote_schema_retriever + from sagemaker_schema_inference_artifacts.huggingface import remote_schema_retriever if model_task in ("text-to-image", "automatic-speech-recognition"): logger.warning( @@ -985,37 +998,36 @@ def _hf_schema_builder_init(self, model_task: str) -> None: "with all models at this time.", model_task, ) - + remote_hf_schema_helper = remote_schema_retriever.RemoteSchemaRetriever() ( sample_inputs, sample_outputs, ) = remote_hf_schema_helper.get_resolved_hf_schema_for_task(model_task) - + self.schema_builder = SchemaBuilder(sample_inputs, sample_outputs) - + except ValueError as e: raise TaskNotFoundException( f"HuggingFace Schema builder samples for {model_task} could not be found " f"locally or via remote." ) from e - def _retrieve_hugging_face_model_mapping(self) -> Dict[str, Dict[str, Any]]: """Retrieve and preprocess HuggingFace/JumpStart model mapping. - + Downloads the mapping file from S3 that contains the correspondence between HuggingFace model IDs and their JumpStart equivalents. - + Returns: Dictionary mapping HuggingFace model IDs to JumpStart model metadata. Empty dict if mapping cannot be retrieved. """ converted_mapping = {} - session = getattr(self, 'sagemaker_session', None) + session = getattr(self, "sagemaker_session", None) if not session: return converted_mapping - + region = session.boto_region_name try: mapping_json_object = JumpStartS3PayloadAccessor.get_object_cached( @@ -1039,22 +1051,19 @@ def _retrieve_hugging_face_model_mapping(self) -> Dict[str, Dict[str, Any]]: def _prepare_hf_model_for_upload(self) -> None: """Download HuggingFace model metadata for upload. - + Creates a temporary directory and downloads the necessary HuggingFace model metadata files if model_path is not already set. """ - model_path = getattr(self, 'model_path', None) + model_path = getattr(self, "model_path", None) if not model_path: self.model_path = f"/tmp/sagemaker/model-builder/{self.model}" - env_vars = getattr(self, 'env_vars', {}) or {} + env_vars = getattr(self, "env_vars", {}) or {} self.download_huggingface_model_metadata( self.model, os.path.join(self.model_path, "code"), env_vars.get("HUGGING_FACE_HUB_TOKEN"), ) - - - # ======================================== # Resource and Hardware Utils @@ -1062,60 +1071,63 @@ def _prepare_hf_model_for_upload(self) -> None: def _get_processing_unit(self) -> str: """Detect if resource requirements are intended for CPU or GPU instance. - + Analyzes resource requirements to determine the target processing unit: - Checks for accelerator requirements in resource_requirements - Checks for accelerator requirements in modelbuilder_list items - Defaults to CPU if no accelerators specified - + Returns: 'gpu' if accelerators are required, 'cpu' otherwise. 
""" # Assume custom orchestrator will be deployed as an endpoint to a CPU instance - resource_requirements = getattr(self, 'resource_requirements', None) - if not resource_requirements or not getattr(resource_requirements, 'num_accelerators', None): - modelbuilder_list = getattr(self, 'modelbuilder_list', None) or [] + resource_requirements = getattr(self, "resource_requirements", None) + if not resource_requirements or not getattr( + resource_requirements, "num_accelerators", None + ): + modelbuilder_list = getattr(self, "modelbuilder_list", None) or [] for ic in modelbuilder_list: - ic_resource_req = getattr(ic, 'resource_requirements', None) - if ic_resource_req and getattr(ic_resource_req, 'num_accelerators', 0) > 0: + ic_resource_req = getattr(ic, "resource_requirements", None) + if ic_resource_req and getattr(ic_resource_req, "num_accelerators", 0) > 0: return "gpu" return "cpu" - - if getattr(resource_requirements, 'num_accelerators', 0) > 0: + + if getattr(resource_requirements, "num_accelerators", 0) > 0: return "gpu" return "cpu" - def _get_inference_component_resource_requirements(self, mb) -> None: """Fetch pre-benchmarked resource requirements from JumpStart. - + Attempts to retrieve and set resource requirements for inference components using JumpStart deployment configurations when available. - + Raises: ValueError: If no resource requirements provided and no JumpStart configs found. """ - resource_requirements = getattr(mb, 'resource_requirements', None) + resource_requirements = getattr(mb, "resource_requirements", None) if mb._is_jumpstart_model_id() and not resource_requirements: - if not hasattr(mb, 'list_deployment_configs'): + if not hasattr(mb, "list_deployment_configs"): return - + deployment_configs = mb.list_deployment_configs() if not deployment_configs: - inference_component_name = getattr(mb, 'inference_component_name', 'Unknown') + inference_component_name = getattr(mb, "inference_component_name", "Unknown") raise ValueError( f"No resource requirements were provided for Inference Component " f"{inference_component_name} and no default deployment " f"configs were found in JumpStart." ) - + compute_requirements = ( - deployment_configs[0].get("DeploymentArgs", {}).get("ComputeResourceRequirements", {}) + deployment_configs[0] + .get("DeploymentArgs", {}) + .get("ComputeResourceRequirements", {}) ) - + logger.debug("Retrieved pre-benchmarked deployment configurations from JumpStart.") - + mb.resource_requirements = ResourceRequirements( requests={ "memory": compute_requirements.get("MinMemoryRequiredInMb"), @@ -1127,9 +1139,8 @@ def _get_inference_component_resource_requirements(self, mb) -> None: }, limits={"memory": compute_requirements.get("MaxMemoryRequiredInMb", None)}, ) - + return mb - def _can_fit_on_single_gpu(self) -> bool: """Check if model can fit on a single GPU. @@ -1141,17 +1152,16 @@ def _can_fit_on_single_gpu(self) -> bool: True if model size <= single GPU memory size, False otherwise. 
""" try: - if not hasattr(self, '_try_fetch_gpu_info'): + if not hasattr(self, "_try_fetch_gpu_info"): return False - + single_gpu_size_mib = self._try_fetch_gpu_info() - env_vars = getattr(self, 'env_vars', {}) or {} - + env_vars = getattr(self, "env_vars", {}) or {} + model_size_mib = _total_inference_model_size_mib( - self.model, - env_vars.get("dtypes", "float32") + self.model, env_vars.get("dtypes", "float32") ) - + if model_size_mib <= single_gpu_size_mib: logger.debug( "Total inference model size: %s MiB, single GPU size: %s MiB", @@ -1160,56 +1170,53 @@ def _can_fit_on_single_gpu(self) -> bool: ) return True return False - + except ValueError: - instance_type = getattr(self, 'instance_type', 'Unknown') + instance_type = getattr(self, "instance_type", "Unknown") logger.debug("Unable to determine single GPU size for instance %s", instance_type) return False - - # ======================================== # Serialization Utils # ======================================== def _extract_framework_from_image_uri(self) -> Tuple[Optional[Framework], Optional[str]]: """Extract framework and version information from SageMaker image URI. - + Analyzes the container image URI to determine the ML framework and version being used. - + Returns: Tuple of (Framework enum, version string). Both can be None if not detected. """ - image_uri = getattr(self, 'image_uri', None) + image_uri = getattr(self, "image_uri", None) if not image_uri: return None, None - + if "pytorch-inference" in image_uri or "pytorch-training" in image_uri: - version_match = re.search(r'pytorch.*:(\d+\.\d+\.\d+)', image_uri) + version_match = re.search(r"pytorch.*:(\d+\.\d+\.\d+)", image_uri) return Framework.PYTORCH, version_match.group(1) if version_match else None - + elif "tensorflow-inference" in image_uri or "tensorflow-training" in image_uri: - version_match = re.search(r'tensorflow.*:(\d+\.\d+\.\d+)', image_uri) + version_match = re.search(r"tensorflow.*:(\d+\.\d+\.\d+)", image_uri) return Framework.TENSORFLOW, version_match.group(1) if version_match else None - + elif "sagemaker-xgboost" in image_uri: - version_match = re.search(r'sagemaker-xgboost:(\d+\.\d+)', image_uri) + version_match = re.search(r"sagemaker-xgboost:(\d+\.\d+)", image_uri) return Framework.XGBOOST, version_match.group(1) if version_match else None - + elif "sagemaker-scikit-learn" in image_uri: - version_match = re.search(r'scikit-learn:(\d+\.\d+)', image_uri) + version_match = re.search(r"scikit-learn:(\d+\.\d+)", image_uri) return Framework.SKLEARN, version_match.group(1) if version_match else None - + elif "huggingface" in image_uri: return Framework.HUGGINGFACE, None - + elif "mxnet" in image_uri: - version_match = re.search(r'mxnet.*:(\d+\.\d+\.\d+)', image_uri) + version_match = re.search(r"mxnet.*:(\d+\.\d+\.\d+)", image_uri) return Framework.MXNET, version_match.group(1) if version_match else None - + return None, None - def _fetch_serializer_and_deserializer_for_framework(self, framework: str) -> Tuple[Any, Any]: """Fetch default serializer and deserializer for a framework. 
@@ -1225,26 +1232,27 @@ def _fetch_serializer_and_deserializer_for_framework(self, framework: str) -> Tu if framework_enum and framework_enum in DEFAULT_SERIALIZERS_BY_FRAMEWORK: return DEFAULT_SERIALIZERS_BY_FRAMEWORK[framework_enum] return NumpySerializer(), JSONDeserializer() - - def _normalize_framework_to_enum(self, framework: Union[str, Framework, None]) -> Optional[Framework]: + def _normalize_framework_to_enum( + self, framework: Union[str, Framework, None] + ) -> Optional[Framework]: """Convert any framework input to Framework enum. - + Args: framework: Framework as string, enum, or None - + Returns: Framework enum or None if not found/None input """ if framework is None: return None - + if isinstance(framework, Framework): return framework - + if not isinstance(framework, str): return None - + framework_mapping = { "xgboost": Framework.XGBOOST, "xgb": Framework.XGBOOST, @@ -1269,9 +1277,8 @@ def _normalize_framework_to_enum(self, framework: Union[str, Framework, None]) - "smd": Framework.SMD, "sagemaker-distribution": Framework.SMD, } - - return framework_mapping.get(framework.lower()) + return framework_mapping.get(framework.lower()) # ======================================== # MLflow Utils @@ -1279,7 +1286,7 @@ def _normalize_framework_to_enum(self, framework: Union[str, Framework, None]) - def _handle_mlflow_input(self) -> None: """Check and handle MLflow model input if present. - + Detects MLflow model arguments, validates metadata existence, and initializes MLflow-specific configurations. """ @@ -1287,19 +1294,19 @@ def _handle_mlflow_input(self) -> None: if not self._is_mlflow_model: return - model_metadata = getattr(self, 'model_metadata', {}) + model_metadata = getattr(self, "model_metadata", {}) mlflow_model_path = model_metadata.get(MLFLOW_MODEL_PATH) if not mlflow_model_path: return - + artifact_path = self._get_artifact_path(mlflow_model_path) if not self._mlflow_metadata_exists(artifact_path): return self._initialize_for_mlflow(artifact_path) - - model_server = getattr(self, 'model_server', None) - env_vars = getattr(self, 'env_vars', {}) or {} + + model_server = getattr(self, "model_server", None) + env_vars = getattr(self, "env_vars", {}) or {} _validate_input_for_mlflow(model_server, env_vars.get("MLFLOW_MODEL_FLAVOR")) def _has_mlflow_arguments(self) -> bool: @@ -1308,9 +1315,9 @@ def _has_mlflow_arguments(self) -> bool: Returns: True if MLflow arguments are present and should be handled, False otherwise. """ - inference_spec = getattr(self, 'inference_spec', None) - model = getattr(self, 'model', None) - + inference_spec = getattr(self, "inference_spec", None) + model = getattr(self, "model", None) + if inference_spec or model: logger.debug( "Either inference spec or model is provided. " @@ -1318,7 +1325,7 @@ def _has_mlflow_arguments(self) -> bool: ) return False - model_metadata = getattr(self, 'model_metadata', None) + model_metadata = getattr(self, "model_metadata", None) if not model_metadata: logger.debug( "No ModelMetadata provided. ModelBuilder is not handling MLflow model input" @@ -1350,16 +1357,16 @@ def _get_artifact_path(self, mlflow_model_path: str) -> str: Returns: Path to the model artifact. - + Raises: ValueError: If tracking ARN not provided for run/registry paths. ImportError: If sagemaker_mlflow not installed. 
""" is_run_id_type = re.match(MLFLOW_RUN_ID_REGEX, mlflow_model_path) is_registry_type = re.match(MLFLOW_REGISTRY_PATH_REGEX, mlflow_model_path) - + if is_run_id_type or is_registry_type: - model_metadata = getattr(self, 'model_metadata', {}) + model_metadata = getattr(self, "model_metadata", {}) mlflow_tracking_arn = model_metadata.get(MLFLOW_TRACKING_ARN) if not mlflow_tracking_arn: raise ValueError( @@ -1375,7 +1382,7 @@ def _get_artifact_path(self, mlflow_model_path: str) -> str: import mlflow mlflow.set_tracking_uri(mlflow_tracking_arn) - + if is_run_id_type: _, run_id, model_path = mlflow_model_path.split("/", 2) artifact_uri = mlflow.get_run(run_id).info.artifact_uri @@ -1391,7 +1398,9 @@ def _get_artifact_path(self, mlflow_model_path: str) -> str: if "@" in mlflow_model_path: _, model_name_and_alias, artifact_uri = mlflow_model_path.split("/", 2) model_name, model_alias = model_name_and_alias.split("@") - model_version_info = mlflow_client.get_model_version_by_alias(model_name, model_alias) + model_version_info = mlflow_client.get_model_version_by_alias( + model_name, model_alias + ) source = mlflow_client.get_model_version_download_uri( model_name, model_version_info.version ) @@ -1405,7 +1414,7 @@ def _get_artifact_path(self, mlflow_model_path: str) -> str: # Handle model package ARN if re.match(MODEL_PACKAGE_ARN_REGEX, mlflow_model_path): - sagemaker_session = getattr(self, 'sagemaker_session', None) + sagemaker_session = getattr(self, "sagemaker_session", None) if sagemaker_session: model_package = sagemaker_session.sagemaker_client.describe_model_package( ModelPackageName=mlflow_model_path @@ -1420,7 +1429,7 @@ def _mlflow_metadata_exists(self, path: str) -> bool: Args: path: Directory path to check (local or S3). - + Returns: True if MLmodel file exists, False otherwise. """ @@ -1429,7 +1438,7 @@ def _mlflow_metadata_exists(self, path: str) -> bool: if not path.endswith("/"): path += "/" s3_uri_to_mlmodel_file = f"{path}{MLFLOW_METADATA_FILE}" - sagemaker_session = getattr(self, 'sagemaker_session', None) + sagemaker_session = getattr(self, "sagemaker_session", None) if not sagemaker_session: return False response = s3_downloader.list(s3_uri_to_mlmodel_file, sagemaker_session) @@ -1446,51 +1455,49 @@ def _initialize_for_mlflow(self, artifact_path: str) -> None: Args: artifact_path: Path to the MLflow artifact store. - + Raises: ValueError: If artifact path is invalid. 
""" - model_path = getattr(self, 'model_path', None) - sagemaker_session = getattr(self, 'sagemaker_session', None) - + model_path = getattr(self, "model_path", None) + sagemaker_session = getattr(self, "sagemaker_session", None) + if artifact_path.startswith("s3://"): _download_s3_artifacts(artifact_path, model_path, sagemaker_session) elif os.path.exists(artifact_path): _copy_directory_contents(artifact_path, model_path) else: raise ValueError(f"Invalid path: {artifact_path}") - + mlflow_model_metadata_path = _generate_mlflow_artifact_path( model_path, MLFLOW_METADATA_FILE ) mlflow_model_dependency_path = _generate_mlflow_artifact_path( model_path, MLFLOW_PIP_DEPENDENCY_FILE ) - + flavor_metadata = _get_all_flavor_metadata(mlflow_model_metadata_path) deployment_flavor = _get_deployment_flavor(flavor_metadata) - current_model_server = getattr(self, 'model_server', None) + current_model_server = getattr(self, "model_server", None) self.model_server = current_model_server or _get_default_model_server_for_mlflow( deployment_flavor ) - - current_image_uri = getattr(self, 'image_uri', None) + + current_image_uri = getattr(self, "image_uri", None) if not current_image_uri: self.image_uri = _select_container_for_mlflow_model( mlflow_model_src_path=model_path, deployment_flavor=deployment_flavor, region=sagemaker_session.boto_region_name if sagemaker_session else None, - instance_type=getattr(self, 'instance_type', None), + instance_type=getattr(self, "instance_type", None), ) - - env_vars = getattr(self, 'env_vars', {}) - env_vars.update({"MLFLOW_MODEL_FLAVOR": f"{deployment_flavor}"}) - - dependencies = getattr(self, 'dependencies', {}) - dependencies.update({"requirements": mlflow_model_dependency_path}) + env_vars = getattr(self, "env_vars", {}) + env_vars.update({"MLFLOW_MODEL_FLAVOR": f"{deployment_flavor}"}) + dependencies = getattr(self, "dependencies", {}) + dependencies.update({"requirements": mlflow_model_dependency_path}) # ======================================== # Optimize Utils @@ -1512,7 +1519,6 @@ def _is_inferentia_or_trainium(self, instance_type: Optional[str]) -> bool: return True return False - def _is_image_compatible_with_optimization_job(self, image_uri: Optional[str]) -> bool: """Checks whether an instance is compatible with an optimization job. @@ -1526,7 +1532,6 @@ def _is_image_compatible_with_optimization_job(self, image_uri: Optional[str]) - return True return "djl-inference:" in image_uri and ("-lmi" in image_uri or "-neuronx-" in image_uri) - def _deployment_config_contains_draft_model(self, deployment_config: Optional[Dict]) -> bool: """Checks whether a deployment config contains a speculative decoding draft model. @@ -1541,8 +1546,9 @@ def _deployment_config_contains_draft_model(self, deployment_config: Optional[Di deployment_args = deployment_config.get("DeploymentArgs", {}) additional_data_sources = deployment_args.get("AdditionalDataSources") - return "speculative_decoding" in additional_data_sources if additional_data_sources else False - + return ( + "speculative_decoding" in additional_data_sources if additional_data_sources else False + ) def _is_draft_model_jumpstart_provided(self, deployment_config: Optional[Dict]) -> bool: """Checks whether a deployment config's draft model is provided by JumpStart. @@ -1566,7 +1572,6 @@ def _is_draft_model_jumpstart_provided(self, deployment_config: Optional[Dict]) continue return False - def _generate_optimized_model(self, optimization_response: dict): """Generates a new optimization model. 
@@ -1591,10 +1596,12 @@ def _generate_optimized_model(self, optimization_response: dict): self.instance_type = deployment_instance_type self.add_tags( - {"Key": Tag.OPTIMIZATION_JOB_NAME, "Value": optimization_response["OptimizationJobName"]} + { + "Key": Tag.OPTIMIZATION_JOB_NAME, + "Value": optimization_response["OptimizationJobName"], + } ) - def _is_optimized(self) -> bool: """Checks whether an optimization model is optimized. @@ -1613,7 +1620,6 @@ def _is_optimized(self) -> bool: return True return False - def _generate_model_source( self, model_data: Optional[Union[Dict[str, Any], str]], accept_eula: Optional[bool] ) -> Optional[Dict[str, Any]]: @@ -1637,7 +1643,6 @@ def _generate_model_source( model_source["S3"]["ModelAccessConfig"] = {"AcceptEula": True} return model_source - def _update_environment_variables( self, env: Optional[Dict[str, str]], new_env: Optional[Dict[str, str]] ) -> Optional[Dict[str, str]]: @@ -1657,9 +1662,9 @@ def _update_environment_variables( env = new_env return env - def _extract_speculative_draft_model_provider( - self, speculative_decoding_config: Optional[Dict] = None, + self, + speculative_decoding_config: Optional[Dict] = None, ) -> Optional[str]: """Extracts speculative draft model provider from speculative decoding config. @@ -1685,9 +1690,9 @@ def _extract_speculative_draft_model_provider( return "auto" - def _extract_additional_model_data_source_s3_uri( - self, additional_model_data_source: Optional[Dict] = None, + self, + additional_model_data_source: Optional[Dict] = None, ) -> Optional[str]: """Extracts model data source s3 uri from a model data source in Pascal case. @@ -1705,9 +1710,9 @@ def _extract_additional_model_data_source_s3_uri( return additional_model_data_source.get("S3DataSource").get("S3Uri") - def _extract_deployment_config_additional_model_data_source_s3_uri( - self, additional_model_data_source: Optional[Dict] = None, + self, + additional_model_data_source: Optional[Dict] = None, ) -> Optional[str]: """Extracts model data source s3 uri from a model data source in snake case. @@ -1725,9 +1730,9 @@ def _extract_deployment_config_additional_model_data_source_s3_uri( return additional_model_data_source.get("s3_data_source").get("s3_uri", None) - def _is_draft_model_gated( - self, draft_model_config: Optional[Dict] = None, + self, + draft_model_config: Optional[Dict] = None, ) -> bool: """Extracts model gated-ness from draft model data source. @@ -1739,9 +1744,9 @@ def _is_draft_model_gated( """ return "hosting_eula_key" in draft_model_config if draft_model_config else False - def _extracts_and_validates_speculative_model_source( - self, speculative_decoding_config: Dict, + self, + speculative_decoding_config: Dict, ) -> str: """Extracts model source from speculative decoding config. @@ -1760,7 +1765,6 @@ def _extracts_and_validates_speculative_model_source( raise ValueError("ModelSource must be provided in speculative decoding config.") return model_source - def _generate_channel_name(self, additional_model_data_sources: Optional[List[Dict]]) -> str: """Generates a channel name. @@ -1776,9 +1780,8 @@ def _generate_channel_name(self, additional_model_data_sources: Optional[List[Di return channel_name - def _generate_additional_model_data_sources( - self, + self, model_source: str, channel_name: str, accept_eula: bool = False, @@ -1811,7 +1814,6 @@ def _generate_additional_model_data_sources( return [additional_model_data_source] - def _is_s3_uri(self, s3_uri: Optional[str]) -> bool: """Checks whether an S3 URI is valid. 
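For reference, _is_s3_uri reduces to a single regex match (its return statement opens the next hunk); a sketch of its behavior with illustrative URIs:

    import re

    S3_URI_PATTERN = "^s3://([^/]+)/?(.*)$"  # the pattern _is_s3_uri matches against

    bool(re.match(S3_URI_PATTERN, "s3://amzn-s3-demo-bucket/models/model.tar.gz"))  # True
    bool(re.match(S3_URI_PATTERN, "/opt/ml/model"))                                 # False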
@@ -1826,7 +1828,6 @@ def _is_s3_uri(self, s3_uri: Optional[str]) -> bool: return re.match("^s3://([^/]+)/?(.*)$", s3_uri) is not None - def _extract_optimization_config_and_env( self, quantization_config: Optional[Dict] = None, @@ -1851,7 +1852,9 @@ def _extract_optimization_config_and_env( compilation_override_env = ( compilation_config.get("OverrideEnvironment") if compilation_config else None ) - sharding_override_env = sharding_config.get("OverrideEnvironment") if sharding_config else None + sharding_override_env = ( + sharding_config.get("OverrideEnvironment") if sharding_config else None + ) if quantization_config is not None: optimization_config["ModelQuantizationConfig"] = quantization_config @@ -1873,7 +1876,6 @@ def _extract_optimization_config_and_env( return None, None, None, None - def _custom_speculative_decoding( self, speculative_decoding_config: Optional[Dict], @@ -1910,7 +1912,7 @@ def _custom_speculative_decoding( def _get_cached_model_specs(self, model_id, version, region, sagemaker_session): """Get cached JumpStart model specs to avoid repeated fetches""" - if not hasattr(self, '_cached_js_model_specs'): + if not hasattr(self, "_cached_js_model_specs"): self._cached_js_model_specs = accessors.JumpStartModelsAccessor.get_model_specs( model_id=model_id, version=version, @@ -1919,7 +1921,6 @@ def _get_cached_model_specs(self, model_id, version, region, sagemaker_session): ) return self._cached_js_model_specs - def _jumpstart_speculative_decoding( self, speculative_decoding_config: Optional[Dict[str, Any]] = None, @@ -1939,7 +1940,7 @@ def _jumpstart_speculative_decoding( "`ModelID` is a required field in `speculative_decoding_config` when " "using JumpStart as draft model provider." ) - + model_version = speculative_decoding_config.get("ModelVersion", "*") accept_eula = speculative_decoding_config.get("AcceptEula", False) channel_name = self._generate_channel_name(self.additional_model_data_sources) @@ -1949,9 +1950,8 @@ def _jumpstart_speculative_decoding( version=model_version, region=sagemaker_session.boto_region_name, sagemaker_session=sagemaker_session, - ) - + model_spec_json = model_specs.to_json() js_bucket = accessors.JumpStartModelsAccessor.get_jumpstart_content_bucket(self.region) @@ -1965,10 +1965,12 @@ def _jumpstart_speculative_decoding( f"{eula_message} Set `AcceptEula`=True in " f"speculative_decoding_config once acknowledged." ) - js_bucket = accessors.JumpStartModelsAccessor.get_jumpstart_gated_content_bucket(self.region) + js_bucket = accessors.JumpStartModelsAccessor.get_jumpstart_gated_content_bucket( + self.region + ) key_prefix = model_spec_json.get("hosting_prepacked_artifact_key") - self.additional_model_data_sources = self. 
_generate_additional_model_data_sources( + self.additional_model_data_sources = self._generate_additional_model_data_sources( f"s3://{js_bucket}/{key_prefix}", channel_name, accept_eula, @@ -1981,7 +1983,6 @@ def _jumpstart_speculative_decoding( {"Key": Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER, "Value": "jumpstart"}, ) - def _optimize_for_hf( self, output_path: str, @@ -2026,9 +2027,7 @@ def _optimize_for_hf( sagemaker_session=self.sagemaker_session, ) else: - self._custom_speculative_decoding( - speculative_decoding_config, False - ) + self._custom_speculative_decoding(speculative_decoding_config, False) if quantization_config or compilation_config or sharding_config: create_optimization_job_args = { @@ -2108,7 +2107,6 @@ def _optimize_prepare_for_hf(self): ) self.env_vars.update(env) - def _is_gated_model(self) -> bool: """Determine if ``this`` Model is Gated @@ -2124,7 +2122,7 @@ def _is_gated_model(self) -> bool: if s3_uri is None: return False return "private" in s3_uri - + def set_js_deployment_config(self, config_name: str, instance_type: str) -> None: """Sets the deployment config to apply to the model. @@ -2150,7 +2148,6 @@ def set_js_deployment_config(self, config_name: str, instance_type: str) -> None self.remove_tag_with_key(Tag.FINE_TUNING_MODEL_PATH) self.remove_tag_with_key(Tag.FINE_TUNING_JOB_NAME) - def _set_additional_model_source( self, speculative_decoding_config: Optional[Dict[str, Any]] = None ) -> None: @@ -2161,15 +2158,15 @@ def _set_additional_model_source( accept_eula (Optional[bool]): For models that require a Model Access Config. """ if speculative_decoding_config: - model_provider = self._extract_speculative_draft_model_provider(speculative_decoding_config) + model_provider = self._extract_speculative_draft_model_provider( + speculative_decoding_config + ) channel_name = self._generate_channel_name(self.additional_model_data_sources) if model_provider in ["sagemaker", "auto"]: additional_model_data_sources = ( - self._deployment_config.get("DeploymentArgs", {}).get( - "AdditionalDataSources" - ) + self._deployment_config.get("DeploymentArgs", {}).get("AdditionalDataSources") if self._deployment_config else None ) @@ -2178,8 +2175,9 @@ def _set_additional_model_source( speculative_decoding_config ) if deployment_config: - if model_provider == "sagemaker" and self._is_draft_model_jumpstart_provided( - deployment_config + if ( + model_provider == "sagemaker" + and self._is_draft_model_jumpstart_provided(deployment_config) ): raise ValueError( "No `Sagemaker` provided draft model was found for " @@ -2284,14 +2282,12 @@ def _get_neuron_model_env_vars( version=neuro_model_version, region=self.region, sagemaker_session=self.sagemaker_session, - ) - + model_spec_json = model_specs.to_json() return model_spec_json.get("hosting_env_vars", {}) - - return None + return None def _set_optimization_image_default( self, create_optimization_job_args: Dict[str, Any] @@ -2313,8 +2309,8 @@ def _set_optimization_image_default( region=self.region, model_version=self.model_version, hub_arn=self.hub_arn, - tolerate_vulnerable_model=getattr(self, 'tolerate_vulnerable_model', None), - tolerate_deprecated_model=getattr(self, 'tolerate_deprecated_model', None) + tolerate_vulnerable_model=getattr(self, "tolerate_vulnerable_model", None), + tolerate_deprecated_model=getattr(self, "tolerate_deprecated_model", None), ) default_image = self._get_default_vllm_image(init_kwargs.image_uri) @@ -2393,33 +2389,31 @@ def _parse_lmi_version(self, image: str) -> Tuple[int, int, int]: Returns: Tuple[int, 
int, int]: LMI version split into major, minor, patch - + Raises: ValueError: If the image format cannot be parsed """ _, dlc_tag = image.split(":") parts = dlc_tag.split("-") - + lmi_version = None for part in parts: if "." in part and part[0].isdigit(): lmi_version = part break - + if not lmi_version: raise ValueError(f"Could not find version in image: {image}") - + version_parts = lmi_version.split(".") if len(version_parts) < 3: raise ValueError(f"Invalid version format: {lmi_version} in image: {image}") - + major_version = int(version_parts[0]) minor_version = int(version_parts[1]) patch_version = int(version_parts[2]) - - return (major_version, minor_version, patch_version) - + return (major_version, minor_version, patch_version) def _optimize_for_jumpstart( self, @@ -2524,7 +2518,9 @@ def _optimize_for_jumpstart( if self._deployment_config else None ) - self.instance_type = instance_type or deployment_config_instance_type or self._get_nb_instance() + self.instance_type = ( + instance_type or deployment_config_instance_type or self._get_nb_instance() + ) create_optimization_job_args = { "OptimizationJobName": job_name, @@ -2549,9 +2545,7 @@ def _optimize_for_jumpstart( if accept_eula: self.accept_eula = accept_eula if isinstance(self.s3_upload_path, dict): - self.s3_upload_path["S3DataSource"]["ModelAccessConfig"] = { - "AcceptEula": True - } + self.s3_upload_path["S3DataSource"]["ModelAccessConfig"] = {"AcceptEula": True} optimization_env_vars = self._update_environment_variables( optimization_env_vars, @@ -2579,11 +2573,12 @@ def _optimize_for_jumpstart( ) return None - def _generate_optimized_core_model(self, optimization_response: dict) -> Model: """Generate optimized CoreModel from optimization job response.""" - - recommended_image_uri = optimization_response.get("OptimizationOutput", {}).get("RecommendedInferenceImage") + + recommended_image_uri = optimization_response.get("OptimizationOutput", {}).get( + "RecommendedInferenceImage" + ) s3_uri = optimization_response.get("OutputConfig", {}).get("S3OutputLocation") deployment_instance_type = optimization_response.get("DeploymentInstanceType") if recommended_image_uri: @@ -2596,15 +2591,15 @@ def _generate_optimized_core_model(self, optimization_response: dict) -> Model: if deployment_instance_type: self.instance_type = deployment_instance_type - self.add_tags({"Key": "OptimizationJobName", "Value": optimization_response["OptimizationJobName"]}) - + self.add_tags( + {"Key": "OptimizationJobName", "Value": optimization_response["OptimizationJobName"]} + ) + self._optimizing = False optimized_core_model = self._create_model() self.built_model = optimized_core_model - - return optimized_core_model - + return optimized_core_model def deployment_config_response_data( self, @@ -2634,7 +2629,7 @@ def deployment_config_response_data( configs.append(deployment_config_json) return configs - + # @_deployment_config_lru_cache def _get_deployment_configs_benchmarks_data(self) -> Dict[str, Any]: """Deployment configs benchmark metrics. 
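As a worked example of _parse_lmi_version above, consider a representative (made-up) LMI image URI:

    image = "123456789012.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi11.0.0-cu124"
    _, dlc_tag = image.split(":")   # "0.29.0-lmi11.0.0-cu124"
    parts = dlc_tag.split("-")      # ["0.29.0", "lmi11.0.0", "cu124"]
    # The first part that contains "." and starts with a digit is "0.29.0",
    # so the helper returns (0, 29, 0).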
 
     def _optimize_for_jumpstart(
         self,
@@ -2524,7 +2518,9 @@ def _optimize_for_jumpstart(
             if self._deployment_config
             else None
         )
-        self.instance_type = instance_type or deployment_config_instance_type or self._get_nb_instance()
+        self.instance_type = (
+            instance_type or deployment_config_instance_type or self._get_nb_instance()
+        )
 
         create_optimization_job_args = {
             "OptimizationJobName": job_name,
@@ -2549,9 +2545,7 @@ def _optimize_for_jumpstart(
         if accept_eula:
             self.accept_eula = accept_eula
             if isinstance(self.s3_upload_path, dict):
-                self.s3_upload_path["S3DataSource"]["ModelAccessConfig"] = {
-                    "AcceptEula": True
-                }
+                self.s3_upload_path["S3DataSource"]["ModelAccessConfig"] = {"AcceptEula": True}
 
         optimization_env_vars = self._update_environment_variables(
             optimization_env_vars,
@@ -2579,11 +2573,12 @@ def _optimize_for_jumpstart(
         )
         return None
 
-
     def _generate_optimized_core_model(self, optimization_response: dict) -> Model:
         """Generate optimized CoreModel from optimization job response."""
-
-        recommended_image_uri = optimization_response.get("OptimizationOutput", {}).get("RecommendedInferenceImage")
+
+        recommended_image_uri = optimization_response.get("OptimizationOutput", {}).get(
+            "RecommendedInferenceImage"
+        )
         s3_uri = optimization_response.get("OutputConfig", {}).get("S3OutputLocation")
         deployment_instance_type = optimization_response.get("DeploymentInstanceType")
 
         if recommended_image_uri:
@@ -2596,15 +2591,15 @@ def _generate_optimized_core_model(self, optimization_response: dict) -> Model:
         if deployment_instance_type:
             self.instance_type = deployment_instance_type
 
-        self.add_tags({"Key": "OptimizationJobName", "Value": optimization_response["OptimizationJobName"]})
-
+        self.add_tags(
+            {"Key": "OptimizationJobName", "Value": optimization_response["OptimizationJobName"]}
+        )
+
         self._optimizing = False
         optimized_core_model = self._create_model()
 
         self.built_model = optimized_core_model
-
-        return optimized_core_model
-
+        return optimized_core_model
 
     def deployment_config_response_data(
         self,
@@ -2634,7 +2629,7 @@ def deployment_config_response_data(
             configs.append(deployment_config_json)
 
         return configs
-
+
     # @_deployment_config_lru_cache
     def _get_deployment_configs_benchmarks_data(self) -> Dict[str, Any]:
         """Deployment configs benchmark metrics.
@@ -2695,20 +2690,20 @@ def _get_deployment_configs(
             sagemaker_session=self.sagemaker_session,
             image_uri=image_uri,
             region=self.region,
-            model_version=getattr(self, 'model_version', None) or "*",
+            model_version=getattr(self, "model_version", None) or "*",
             hub_arn=self.hub_arn,
-            tolerate_vulnerable_model=getattr(self, 'tolerate_vulnerable_model', None),
-            tolerate_deprecated_model=getattr(self, 'tolerate_deprecated_model', None)
+            tolerate_vulnerable_model=getattr(self, "tolerate_vulnerable_model", None),
+            tolerate_deprecated_model=getattr(self, "tolerate_deprecated_model", None),
         )
         deploy_kwargs = get_deploy_kwargs(
             model_id=self.model,
             instance_type=instance_type_to_use,
             sagemaker_session=self.sagemaker_session,
             region=self.region,
-            model_version=getattr(self, 'model_version', None) or "*",
+            model_version=getattr(self, "model_version", None) or "*",
             hub_arn=self.hub_arn,
-            tolerate_vulnerable_model=getattr(self, 'tolerate_vulnerable_model', None),
-            tolerate_deprecated_model=getattr(self, 'tolerate_deprecated_model', None)
+            tolerate_vulnerable_model=getattr(self, "tolerate_vulnerable_model", None),
+            tolerate_deprecated_model=getattr(self, "tolerate_deprecated_model", None),
         )
 
         deployment_config_metadata = DeploymentConfigMetadata(
@@ -2725,8 +2720,6 @@ def _get_deployment_configs(
 
         return deployment_configs
 
-
-
     # ========================================
     # General Utils
     # ========================================
@@ -2737,7 +2730,7 @@ def add_tags(self, tags: Tags) -> None:
 
         Args:
             tags: Tags to add to the model.
         """
-        current_tags = getattr(self, '_tags', None)
+        current_tags = getattr(self, "_tags", None)
         self._tags = _validate_new_tags(tags, current_tags)
 
     def remove_tag_with_key(self, key: str) -> None:
@@ -2746,107 +2739,109 @@ def remove_tag_with_key(self, key: str) -> None:
 
         Args:
             key: The key of the tag to remove.
         """
-        current_tags = getattr(self, '_tags', None)
+        current_tags = getattr(self, "_tags", None)
         self._tags = remove_tag_with_key(key, current_tags)
 
     def _get_model_uri(self) -> Optional[str]:
         """Extract model URI from s3_model_data_url.
-
+
         Returns:
             Model URI string, or None if not available.
         """
-        s3_model_data_url = getattr(self, 's3_model_data_url', None)
+        s3_model_data_url = getattr(self, "s3_model_data_url", None)
         if not s3_model_data_url:
             return None
-
+
         if isinstance(s3_model_data_url, (str, PipelineVariable)):
             return s3_model_data_url
         elif isinstance(s3_model_data_url, dict):
             return s3_model_data_url.get("S3DataSource", {}).get("S3Uri", None)
         return None
 
-    def _ensure_base_name_if_needed(self, image_uri: str, script_uri: Optional[str], model_uri: Optional[str]) -> None:
+    def _ensure_base_name_if_needed(
+        self, image_uri: str, script_uri: Optional[str], model_uri: Optional[str]
+    ) -> None:
         """Create base name from image URI if no model name provided.
 
         Uses JumpStart base name if available, otherwise derives from image URI.
-
+
         Args:
             image_uri: Container image URI
             script_uri: Optional script URI for JumpStart models
             model_uri: Optional model URI for JumpStart models
         """
-        model_name = getattr(self, 'model_name', None)
+        model_name = getattr(self, "model_name", None)
         if model_name is None:
-            base_name = getattr(self, '_base_name', None)
+            base_name = getattr(self, "_base_name", None)
             self._base_name = (
                 base_name
                 or get_jumpstart_base_name_if_jumpstart_model(script_uri, model_uri)
                 or base_name_from_image(image_uri, default_base_name="ModelBuilder")
             )
 
-
     def _ensure_metadata_configs(self) -> None:
         """Lazy load JumpStart metadata configs when needed."""
-        metadata_configs = getattr(self, '_metadata_configs', None)
-        model = getattr(self, 'model', None)
-
+        metadata_configs = getattr(self, "_metadata_configs", None)
+        model = getattr(self, "model", None)
+
         if metadata_configs is None and isinstance(model, str):
             from sagemaker.core.jumpstart.utils import get_jumpstart_configs
-
+
             self._metadata_configs = get_jumpstart_configs(
                 region=self.region,
                 model_id=model,
-                model_version=getattr(self, 'model_version', None) or "*",
-                sagemaker_session=getattr(self, 'sagemaker_session', None),
+                model_version=getattr(self, "model_version", None) or "*",
+                sagemaker_session=getattr(self, "sagemaker_session", None),
             )
-
+
     def _user_agent_decorator(self, func):
         """Decorator to add ModelBuilder to user agent string.
-
+
         Args:
             func: Function to decorate
-
+
         Returns:
             Decorated function that appends ModelBuilder to user agent.
         """
+
         def wrapper(*args, **kwargs):
             # Call the original function
             result = func(*args, **kwargs)
             if "ModelBuilder" in result:
                 return result
             return result + " ModelBuilder"
+
         return wrapper
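_user_agent_decorator appends the suffix at most once; the same pattern, shown module-level for brevity (the base user-agent string below is made up):

def tag_user_agent(func):
    # Wraps a function that returns a user-agent string and appends
    # "ModelBuilder" exactly once, no matter how often it is wrapped.
    def wrapper(*args, **kwargs):
        result = func(*args, **kwargs)
        if "ModelBuilder" in result:
            return result
        return result + " ModelBuilder"
    return wrapper

ua = tag_user_agent(lambda: "sagemaker-python-sdk/3.0")
assert ua() == "sagemaker-python-sdk/3.0 ModelBuilder"
assert tag_user_agent(ua)() == "sagemaker-python-sdk/3.0 ModelBuilder"  # idempotent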
""" - s3_model_data_url = getattr(self, 's3_model_data_url', None) + s3_model_data_url = getattr(self, "s3_model_data_url", None) if not s3_model_data_url: - sagemaker_session = getattr(self, 'sagemaker_session', None) + sagemaker_session = getattr(self, "sagemaker_session", None) if sagemaker_session: bucket = sagemaker_session.default_bucket() - model_name = getattr(self, 'model_name', None) + model_name = getattr(self, "model_name", None) prefix = f"model-builder/{model_name or 'model'}/{uuid.uuid4().hex}" self.s3_model_data_url = f"s3://{bucket}/{prefix}/" - + return _ServeSettings( - role_arn=getattr(self, 'role_arn', None), - s3_model_data_url=getattr(self, 's3_model_data_url', None), - instance_type=getattr(self, 'instance_type', None), - env_vars=getattr(self, 'env_vars', None), - sagemaker_session=getattr(self, 'sagemaker_session', None), + role_arn=getattr(self, "role_arn", None), + s3_model_data_url=getattr(self, "s3_model_data_url", None), + instance_type=getattr(self, "instance_type", None), + env_vars=getattr(self, "env_vars", None), + sagemaker_session=getattr(self, "sagemaker_session", None), ) - def _is_jumpstart_model_id(self) -> bool: """Check if model is a JumpStart model ID.""" - if not hasattr(self, '_cached_is_jumpstart'): + if not hasattr(self, "_cached_is_jumpstart"): if self.model is None: self._cached_is_jumpstart = False return self._cached_is_jumpstart @@ -2863,7 +2858,6 @@ def _is_jumpstart_model_id(self) -> bool: return self._cached_is_jumpstart return self._cached_is_jumpstart - def _has_nvidia_gpu(self) -> bool: try: @@ -2873,7 +2867,7 @@ def _has_nvidia_gpu(self) -> bool: # for nvidia-smi to run, a cuda driver must be present logger.debug("CUDA not found, launching Triton in CPU mode.") return False - + def _is_gpu_instance(self, instance_type: str) -> bool: instance_family = instance_type.rsplit(".", 1)[0] return instance_family in GPU_INSTANCE_FAMILIES @@ -2883,21 +2877,18 @@ def _save_inference_spec(self) -> None: if self.inference_spec: pkl_path = Path(self.model_path).joinpath("model_repository").joinpath("model") save_pkl(pkl_path, (self.inference_spec, self.schema_builder)) - - def _hmac_signing(self): - """Perform HMAC signing on picke file for integrity check""" - secret_key = generate_secret_key() + + def _compute_integrity_hash(self): + """Compute SHA-256 hash of serve.pkl and store in metadata.json for integrity check.""" pkl_path = Path(self.model_path).joinpath("model_repository").joinpath("model") with open(str(pkl_path.joinpath("serve.pkl")), "rb") as f: buffer = f.read() - hash_value = compute_hash(buffer=buffer, secret_key=secret_key) + hash_value = compute_hash(buffer=buffer) with open(str(pkl_path.joinpath("metadata.json")), "wb") as metadata: metadata.write(_MetaData(hash_value).to_json()) - self.secret_key = secret_key - def _generate_config_pbtxt(self, pkl_path: Path): """Generate Triton config.pbtxt file.""" config_path = pkl_path.joinpath("config.pbtxt") @@ -2924,6 +2915,7 @@ def _pack_conda_env(self, pkl_path: Path): """Pack conda environment for Triton deployment.""" try: import conda_pack + conda_pack.__version__ except ModuleNotFoundError: raise ImportError( @@ -2991,11 +2983,12 @@ def _export_pytorch_to_onnx( "And follow the ones that match your environment. " "Please note that you may need to restart your runtime after installation." 
@@ -2883,21 +2877,18 @@ def _save_inference_spec(self) -> None:
         if self.inference_spec:
             pkl_path = Path(self.model_path).joinpath("model_repository").joinpath("model")
 
             save_pkl(pkl_path, (self.inference_spec, self.schema_builder))
-
-    def _hmac_signing(self):
-        """Perform HMAC signing on picke file for integrity check"""
-        secret_key = generate_secret_key()
+
+    def _compute_integrity_hash(self):
+        """Compute SHA-256 hash of serve.pkl and store in metadata.json for integrity check."""
         pkl_path = Path(self.model_path).joinpath("model_repository").joinpath("model")
 
         with open(str(pkl_path.joinpath("serve.pkl")), "rb") as f:
             buffer = f.read()
-        hash_value = compute_hash(buffer=buffer, secret_key=secret_key)
+        hash_value = compute_hash(buffer=buffer)
         with open(str(pkl_path.joinpath("metadata.json")), "wb") as metadata:
             metadata.write(_MetaData(hash_value).to_json())
 
-        self.secret_key = secret_key
-
     def _generate_config_pbtxt(self, pkl_path: Path):
         """Generate Triton config.pbtxt file."""
         config_path = pkl_path.joinpath("config.pbtxt")
@@ -2924,6 +2915,7 @@ def _pack_conda_env(self, pkl_path: Path):
         """Pack conda environment for Triton deployment."""
         try:
             import conda_pack
+
             conda_pack.__version__
         except ModuleNotFoundError:
             raise ImportError(
@@ -2991,11 +2983,12 @@ def _export_pytorch_to_onnx(
                 "And follow the ones that match your environment. "
                 "Please note that you may need to restart your runtime after installation."
             )
-
+
     def _validate_for_triton(self):
         """Validation for Triton deployment."""
         try:
             import tritonclient.http as httpClient
+
             httpClient.__class__
         except ModuleNotFoundError:
             raise ImportError(
@@ -3075,7 +3068,8 @@ def _prepare_for_triton(self):
             export_path.mkdir(parents=True)
 
         if self.model:
-            self.secret_key = "dummy secret key for onnx backend"
+            # ONNX path: no pickle serialization, no serve.pkl, no integrity check needed.
+            # Do not set secret_key — there is nothing to sign.
 
             if self.framework == Framework.PYTORCH:
                 self._export_pytorch_to_onnx(
@@ -3099,12 +3093,11 @@ def _prepare_for_triton(self):
 
             self._pack_conda_env(pkl_path=pkl_path)
 
-            self._hmac_signing()
+            self._compute_integrity_hash()
 
             return
 
         raise ValueError("Either model or inference_spec should be provided to ModelBuilder.")
 
-
     def _auto_detect_image_for_triton(self):
         """Detect image of triton given framework, version and region.
 
@@ -3153,7 +3146,6 @@ def _auto_detect_image_for_triton(self):
             self.image_uri += "-cpu"
 
         logger.debug(f"Autodetected image: {self.image_uri}. Proceeding with the deployment.")
 
-
     def _validate_djl_serving_sample_data(self):
         """Validate sample data format for DJL serving."""
 
@@ -3169,7 +3161,7 @@ def _validate_djl_serving_sample_data(self):
             or "generated_text" not in sample_output[0]
         ):
             raise ValueError(_INVALID_DJL_SAMPLE_DATA_EX)
-
+
     def _validate_tgi_serving_sample_data(self):
         """Validate sample data format for TGI serving."""
         sample_input = self.schema_builder.sample_input
@@ -3184,7 +3176,7 @@ def _validate_tgi_serving_sample_data(self):
             or "generated_text" not in sample_output[0]
         ):
             raise ValueError(_INVALID_TGI_SAMPLE_DATA_EX)
-
+
     def _create_conda_env(self):
         """Create conda environment by running commands."""
         try:
@@ -3192,13 +3184,16 @@ def _create_conda_env(self):
         except subprocess.CalledProcessError:
             logger.error("Failed to create and activate conda environment.")
 
-
-    def _extract_framework_from_model_trainer(self, model_trainer: ModelTrainer) -> Optional[Framework]:
+    def _extract_framework_from_model_trainer(
+        self, model_trainer: ModelTrainer
+    ) -> Optional[Framework]:
         """Extract framework from ModelTrainer training image."""
         training_image = model_trainer.training_image
         if not training_image:
-            training_image = model_trainer._latest_training_job.algorithm_specification.training_image
-
+            training_image = (
+                model_trainer._latest_training_job.algorithm_specification.training_image
+            )
+
         if "pytorch" in training_image.lower():
             return Framework.PYTORCH
         elif "tensorflow" in training_image.lower():
@@ -3207,15 +3202,16 @@ def _extract_framework_from_model_trainer(
             return Framework.HUGGINGFACE
         elif "xgboost" in training_image.lower():
             return Framework.XGBOOST
-
-        return None
+        return None
 
-    def _infer_model_server_from_training(self, model_trainer: ModelTrainer) -> Optional[ModelServer]:
+    def _infer_model_server_from_training(
+        self, model_trainer: ModelTrainer
+    ) -> Optional[ModelServer]:
         """Infer the best model server based on training configuration."""
         training_image = model_trainer.training_image
         framework = self._extract_framework_from_model_trainer(model_trainer)
-
+
         if "huggingface" in training_image.lower():
             hyperparams = model_trainer.hyperparameters or {}
             if any(key in hyperparams for key in ["max_new_tokens", "do_sample", "temperature"]):
@@ -3224,29 +3220,30 @@ def _infer_model_server_from_training(self, model_trainer: ModelTrainer) -> Opti
             else:
                 logger.info("Auto-detected model server: MMS (HuggingFace)")
                 return ModelServer.MMS
-
+
         if framework == Framework.PYTORCH:
             logger.info("Auto-detected model server: TORCHSERVE (PyTorch framework)")
             return ModelServer.TORCHSERVE
-
+
         if framework == Framework.TENSORFLOW:
             logger.info("Auto-detected model server: TENSORFLOW_SERVING (TensorFlow framework)")
             return ModelServer.TENSORFLOW_SERVING
-
+
         logger.warning(
             f"Could not auto-detect model server for framework: {framework}. "
             "Defaulting to TORCHSERVE. Consider explicitly setting model_server parameter."
         )
         return ModelServer.TORCHSERVE
 
-
-    def _extract_inference_spec_from_training_code(self, model_trainer: ModelTrainer) -> Optional[str]:
+    def _extract_inference_spec_from_training_code(
+        self, model_trainer: ModelTrainer
+    ) -> Optional[str]:
         """Check if training source code contains inference.py."""
         if not model_trainer.source_code or not model_trainer.source_code.source_dir:
             return None
-
+
         source_dir = model_trainer.source_code.source_dir
-
+
         # Check for inference.py in source directory
         if source_dir.startswith("s3://"):
             pass
@@ -3254,59 +3251,65 @@ def _extract_inference_spec_from_training_code(
             inference_path = os.path.join(source_dir, "inference.py")
             if os.path.exists(inference_path):
                 return inference_path
-
+
         return None
 
-
     def _inherit_training_environment(self, model_trainer: ModelTrainer) -> Dict[str, str]:
         """Inherit relevant environment variables from training."""
         from sagemaker.core.utils.utils import Unassigned
-
+
         training_env = model_trainer.environment or {}
         if isinstance(training_env, Unassigned):
             training_env = {}
-
+
         training_job_env = model_trainer._latest_training_job.environment
         if isinstance(training_job_env, Unassigned) or training_job_env is None:
             training_job_env = {}
-
+
         inherited_env = {**training_env, **training_job_env}
 
         inference_relevant_keys = [
-            "HUGGING_FACE_HUB_TOKEN", "HF_TOKEN",
-            "MODEL_CLASS_NAME", "TRANSFORMERS_CACHE",
-            "PYTORCH_TRANSFORMERS_CACHE", "HF_HOME"
+            "HUGGING_FACE_HUB_TOKEN",
+            "HF_TOKEN",
+            "MODEL_CLASS_NAME",
+            "TRANSFORMERS_CACHE",
+            "PYTORCH_TRANSFORMERS_CACHE",
+            "HF_HOME",
         ]
-
-        return {k: v for k, v in inherited_env.items()
-                if k in inference_relevant_keys or k.startswith("SAGEMAKER_")}
-
+
+        return {
+            k: v
+            for k, v in inherited_env.items()
+            if k in inference_relevant_keys or k.startswith("SAGEMAKER_")
+        }
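The filter in _inherit_training_environment keeps allow-listed inference keys plus anything namespaced SAGEMAKER_; a worked example with made-up values:

inherited_env = {
    "HF_TOKEN": "hf_xxx",                   # kept: on the allow list
    "SAGEMAKER_CONTAINER_LOG_LEVEL": "10",  # kept: SAGEMAKER_ prefix
    "WANDB_API_KEY": "train-only",          # dropped: training-only concern
}
inference_relevant_keys = [
    "HUGGING_FACE_HUB_TOKEN", "HF_TOKEN", "MODEL_CLASS_NAME",
    "TRANSFORMERS_CACHE", "PYTORCH_TRANSFORMERS_CACHE", "HF_HOME",
]
filtered = {
    k: v
    for k, v in inherited_env.items()
    if k in inference_relevant_keys or k.startswith("SAGEMAKER_")
}
assert "WANDB_API_KEY" not in filtered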
 
     def _extract_version_from_training_image(self, training_image: str) -> Optional[str]:
         """Extract framework version from training image URI."""
         import re
-
-        version_match = re.search(r':(\d+\.\d+(?:\.\d+)?)', training_image)
+
+        version_match = re.search(r":(\d+\.\d+(?:\.\d+)?)", training_image)
         if version_match:
             return version_match.group(1)
-
-        return None
+        return None
 
     def _detect_inference_image_from_training(self) -> None:
         """Detect inference image based on ModelTrainer's training image."""
         from sagemaker.core import image_uris
+
         training_image = self.model.training_image
-
+
         if "pytorch-training" in training_image:
             self.image_uri = training_image.replace("pytorch-training", "pytorch-inference")
         elif "tensorflow-training" in training_image:
             self.image_uri = training_image.replace("tensorflow-training", "tensorflow-inference")
         elif "huggingface-pytorch-training" in training_image:
-            self.image_uri = training_image.replace("huggingface-pytorch-training", "huggingface-pytorch-inference")
+            self.image_uri = training_image.replace(
+                "huggingface-pytorch-training", "huggingface-pytorch-inference"
+            )
         else:
             framework = self._extract_framework_from_model_trainer(self.model)
             fw = framework.value.lower() if framework else "pytorch"
-
+
             fw_version = self._extract_version_from_training_image(training_image)
             py_tuple = platform.python_version_tuple()
             casted_versions = _cast_to_compatible_version(fw, fw_version) if fw_version else (None,)
@@ -3325,12 +3328,13 @@ def _detect_inference_image_from_training(self) -> None:
                         break
                 except ValueError:
                     pass
-
+
             if dlc:
                 self.image_uri = dlc
             else:
-                raise ValueError(f"Could not detect inference image for training image: {training_image}")
-
+                raise ValueError(
+                    f"Could not detect inference image for training image: {training_image}"
+                )
 
     def _extract_speculative_draft_model_provider(
         self,
@@ -3359,9 +3363,10 @@ def _extract_speculative_draft_model_provider(
             return "sagemaker"
 
         return "auto"
 
-
-    def get_huggingface_model_metadata(self, model_id: str, hf_hub_token: Optional[str] = None) -> dict:
+    def get_huggingface_model_metadata(
+        self, model_id: str, hf_hub_token: Optional[str] = None
+    ) -> dict:
         """Retrieves the json metadata of the HuggingFace Model via HuggingFace API.
 
         Args:
@@ -3405,7 +3410,6 @@ def get_huggingface_model_metadata(
         )
 
         return hf_model_metadata_json
 
-
     def download_huggingface_model_metadata(
         self, model_id: str, model_local_path: str, hf_hub_token: Optional[str] = None
     ) -> None:
@@ -3420,11 +3424,12 @@ def download_huggingface_model_metadata(
             ImportError: If huggingface_hub is not installed.
         """
         if not importlib.util.find_spec("huggingface_hub"):
-            raise ImportError("Unable to import huggingface_hub, check if huggingface_hub is installed")
+            raise ImportError(
+                "Unable to import huggingface_hub, check if huggingface_hub is installed"
+            )
         from huggingface_hub import snapshot_download
 
         os.makedirs(model_local_path, exist_ok=True)
         logger.info("Downloading model %s from Hugging Face Hub to %s", model_id, model_local_path)
         snapshot_download(repo_id=model_id, local_dir=model_local_path, token=hf_hub_token)
-
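A caller's-eye sketch of download_huggingface_model_metadata above; huggingface_hub.snapshot_download is the real dependency the code guards for, while the model id and path here are placeholders:

from huggingface_hub import snapshot_download

# Mirrors the whole model repo into a local directory; a token is only
# needed for gated or private repositories.
snapshot_download(repo_id="sshleifer/tiny-gpt2", local_dir="/tmp/tiny-gpt2", token=None)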
"SAGEMAKER_PROGRAM": "inference.py", - "SAGEMAKER_SERVE_SECRET_KEY": secret_key, "LOCAL_PYTHON": platform.python_version(), } if env_vars: @@ -47,7 +46,7 @@ def _start_serving( image, "serve", # network_mode="host", - ports={'8080/tcp': 8080}, + ports={"8080/tcp": 8080}, detach=True, auto_remove=True, volumes={ @@ -131,7 +130,6 @@ def _upload_server_artifacts( env_vars = { "SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code", "SAGEMAKER_PROGRAM": "inference.py", - "SAGEMAKER_SERVE_SECRET_KEY": secret_key, "SAGEMAKER_REGION": sagemaker_session.boto_region_name, "SAGEMAKER_CONTAINER_LOG_LEVEL": "10", "LOCAL_PYTHON": platform.python_version(), diff --git a/sagemaker-serve/src/sagemaker/serve/model_server/smd/prepare.py b/sagemaker-serve/src/sagemaker/serve/model_server/smd/prepare.py index b66de32bf7..f29b8ebcbd 100644 --- a/sagemaker-serve/src/sagemaker/serve/model_server/smd/prepare.py +++ b/sagemaker-serve/src/sagemaker/serve/model_server/smd/prepare.py @@ -12,7 +12,6 @@ from sagemaker.serve.spec.inference_spec import InferenceSpec from sagemaker.serve.detector.dependency_manager import capture_dependencies from sagemaker.serve.validations.check_integrity import ( - generate_secret_key, compute_hash, ) from sagemaker.core.remote_function.core.serialization import _MetaData @@ -64,11 +63,8 @@ def prepare_for_smd( capture_dependencies(dependencies=dependencies, work_dir=code_dir) - secret_key = generate_secret_key() with open(str(code_dir.joinpath("serve.pkl")), "rb") as f: buffer = f.read() - hash_value = compute_hash(buffer=buffer, secret_key=secret_key) + hash_value = compute_hash(buffer=buffer) with open(str(code_dir.joinpath("metadata.json")), "wb") as metadata: metadata.write(_MetaData(hash_value).to_json()) - - return secret_key diff --git a/sagemaker-serve/src/sagemaker/serve/model_server/smd/server.py b/sagemaker-serve/src/sagemaker/serve/model_server/smd/server.py index e40dc3aa61..ecb68406c1 100644 --- a/sagemaker-serve/src/sagemaker/serve/model_server/smd/server.py +++ b/sagemaker-serve/src/sagemaker/serve/model_server/smd/server.py @@ -53,7 +53,6 @@ def _upload_smd_artifacts( "SAGEMAKER_INFERENCE_CODE_DIRECTORY": "/opt/ml/model/code", "SAGEMAKER_INFERENCE_CODE": "inference.handler", "SAGEMAKER_REGION": sagemaker_session.boto_region_name, - "SAGEMAKER_SERVE_SECRET_KEY": secret_key, "LOCAL_PYTHON": platform.python_version(), } return s3_upload_path, env_vars diff --git a/sagemaker-serve/src/sagemaker/serve/model_server/tei/server.py b/sagemaker-serve/src/sagemaker/serve/model_server/tei/server.py index 9f2f4b71b3..c23c52a513 100644 --- a/sagemaker-serve/src/sagemaker/serve/model_server/tei/server.py +++ b/sagemaker-serve/src/sagemaker/serve/model_server/tei/server.py @@ -38,8 +38,6 @@ def _start_tei_serving( secret_key: Secret key to use for authentication env_vars: Environment variables to set """ - if env_vars and secret_key: - env_vars["SAGEMAKER_SERVE_SECRET_KEY"] = secret_key self.container = client.containers.run( image, diff --git a/sagemaker-serve/src/sagemaker/serve/model_server/tensorflow_serving/prepare.py b/sagemaker-serve/src/sagemaker/serve/model_server/tensorflow_serving/prepare.py index 3525cc9b4a..d56d0ec7bd 100644 --- a/sagemaker-serve/src/sagemaker/serve/model_server/tensorflow_serving/prepare.py +++ b/sagemaker-serve/src/sagemaker/serve/model_server/tensorflow_serving/prepare.py @@ -11,7 +11,6 @@ ) from sagemaker.serve.detector.dependency_manager import capture_dependencies from sagemaker.serve.validations.check_integrity import ( - generate_secret_key, 
compute_hash, ) from sagemaker.core.remote_function.core.serialization import _MetaData @@ -56,12 +55,9 @@ def prepare_for_tf_serving( if not mlflow_saved_model_dir: raise ValueError("SavedModel is not found for Tensorflow or Keras flavor.") _move_contents(src_dir=mlflow_saved_model_dir, dest_dir=saved_model_bundle_dir) - - secret_key = generate_secret_key() + with open(str(code_dir.joinpath("serve.pkl")), "rb") as f: buffer = f.read() - hash_value = compute_hash(buffer=buffer, secret_key=secret_key) + hash_value = compute_hash(buffer=buffer) with open(str(code_dir.joinpath("metadata.json")), "wb") as metadata: metadata.write(_MetaData(hash_value).to_json()) - - return secret_key diff --git a/sagemaker-serve/src/sagemaker/serve/model_server/tensorflow_serving/server.py b/sagemaker-serve/src/sagemaker/serve/model_server/tensorflow_serving/server.py index 2f4a959528..cbd6412d34 100644 --- a/sagemaker-serve/src/sagemaker/serve/model_server/tensorflow_serving/server.py +++ b/sagemaker-serve/src/sagemaker/serve/model_server/tensorflow_serving/server.py @@ -37,7 +37,7 @@ def _start_tensorflow_serving( detach=True, auto_remove=False, # Temporarily disabled to see crash logs # network_mode="host", - ports={'8501/tcp': 8501}, + ports={"8501/tcp": 8501}, volumes={ Path(model_path): { "bind": "/opt/ml/model", @@ -47,7 +47,6 @@ def _start_tensorflow_serving( environment={ "SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code", "SAGEMAKER_PROGRAM": "inference.py", - "SAGEMAKER_SERVE_SECRET_KEY": secret_key, "LOCAL_PYTHON": platform.python_version(), **env_vars, }, @@ -124,7 +123,6 @@ def _upload_tensorflow_serving_artifacts( "SAGEMAKER_PROGRAM": "inference.py", "SAGEMAKER_REGION": sagemaker_session.boto_region_name, "SAGEMAKER_CONTAINER_LOG_LEVEL": "10", - "SAGEMAKER_SERVE_SECRET_KEY": secret_key, "LOCAL_PYTHON": platform.python_version(), } return s3_upload_path, env_vars diff --git a/sagemaker-serve/src/sagemaker/serve/model_server/torchserve/prepare.py b/sagemaker-serve/src/sagemaker/serve/model_server/torchserve/prepare.py index 988acf646d..ad053d25c9 100644 --- a/sagemaker-serve/src/sagemaker/serve/model_server/torchserve/prepare.py +++ b/sagemaker-serve/src/sagemaker/serve/model_server/torchserve/prepare.py @@ -13,7 +13,6 @@ from sagemaker.serve.spec.inference_spec import InferenceSpec from sagemaker.serve.detector.dependency_manager import capture_dependencies from sagemaker.serve.validations.check_integrity import ( - generate_secret_key, compute_hash, ) from sagemaker.serve.validations.check_image_uri import is_1p_image_uri @@ -56,7 +55,9 @@ def prepare_for_torchserve( # https://github.com/aws/sagemaker-python-sdk/issues/4288 if is_1p_image_uri(image_uri=image_uri) and "xgboost" in image_uri: shutil.copy2(Path(__file__).parent.joinpath("xgboost_inference.py"), code_dir) - os.rename(str(code_dir.joinpath("xgboost_inference.py")), str(code_dir.joinpath("inference.py"))) + os.rename( + str(code_dir.joinpath("xgboost_inference.py")), str(code_dir.joinpath("inference.py")) + ) else: shutil.copy2(Path(__file__).parent.joinpath("inference.py"), code_dir) @@ -67,11 +68,8 @@ def prepare_for_torchserve( capture_dependencies(dependencies=dependencies, work_dir=code_dir) - secret_key = generate_secret_key() with open(str(code_dir.joinpath("serve.pkl")), "rb") as f: buffer = f.read() - hash_value = compute_hash(buffer=buffer, secret_key=secret_key) + hash_value = compute_hash(buffer=buffer) with open(str(code_dir.joinpath("metadata.json")), "wb") as metadata: metadata.write(_MetaData(hash_value).to_json()) 
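Every prepare_for_* module in this patch converges on the same write-side pattern: hash serve.pkl with unkeyed SHA-256 and persist the digest beside it, instead of minting and returning a per-build secret. Schematically (the path is a placeholder, and the real code wraps the digest in _MetaData(...).to_json() rather than writing it raw):

import hashlib
from pathlib import Path

code_dir = Path("/tmp/model/code")  # placeholder
buffer = (code_dir / "serve.pkl").read_bytes()
digest = hashlib.sha256(buffer).hexdigest()
# Simplified: the SDK serializes this via _MetaData before writing.
(code_dir / "metadata.json").write_bytes(digest.encode())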
diff --git a/sagemaker-serve/src/sagemaker/serve/model_server/torchserve/prepare.py b/sagemaker-serve/src/sagemaker/serve/model_server/torchserve/prepare.py
index 988acf646d..ad053d25c9 100644
--- a/sagemaker-serve/src/sagemaker/serve/model_server/torchserve/prepare.py
+++ b/sagemaker-serve/src/sagemaker/serve/model_server/torchserve/prepare.py
@@ -13,7 +13,6 @@
 from sagemaker.serve.spec.inference_spec import InferenceSpec
 from sagemaker.serve.detector.dependency_manager import capture_dependencies
 from sagemaker.serve.validations.check_integrity import (
-    generate_secret_key,
     compute_hash,
 )
 from sagemaker.serve.validations.check_image_uri import is_1p_image_uri
@@ -56,7 +55,9 @@ def prepare_for_torchserve(
     # https://github.com/aws/sagemaker-python-sdk/issues/4288
     if is_1p_image_uri(image_uri=image_uri) and "xgboost" in image_uri:
         shutil.copy2(Path(__file__).parent.joinpath("xgboost_inference.py"), code_dir)
-        os.rename(str(code_dir.joinpath("xgboost_inference.py")), str(code_dir.joinpath("inference.py")))
+        os.rename(
+            str(code_dir.joinpath("xgboost_inference.py")), str(code_dir.joinpath("inference.py"))
+        )
     else:
         shutil.copy2(Path(__file__).parent.joinpath("inference.py"), code_dir)
@@ -67,11 +68,8 @@ def prepare_for_torchserve(
 
     capture_dependencies(dependencies=dependencies, work_dir=code_dir)
 
-    secret_key = generate_secret_key()
     with open(str(code_dir.joinpath("serve.pkl")), "rb") as f:
         buffer = f.read()
-    hash_value = compute_hash(buffer=buffer, secret_key=secret_key)
+    hash_value = compute_hash(buffer=buffer)
     with open(str(code_dir.joinpath("metadata.json")), "wb") as metadata:
         metadata.write(_MetaData(hash_value).to_json())
-
-    return secret_key
\ No newline at end of file
diff --git a/sagemaker-serve/src/sagemaker/serve/model_server/torchserve/server.py b/sagemaker-serve/src/sagemaker/serve/model_server/torchserve/server.py
index 0d237df987..9cc4e6196f 100644
--- a/sagemaker-serve/src/sagemaker/serve/model_server/torchserve/server.py
+++ b/sagemaker-serve/src/sagemaker/serve/model_server/torchserve/server.py
@@ -29,7 +29,7 @@ def _start_torch_serve(
             detach=True,
             auto_remove=True,
             # network_mode="host",
-            ports={'8080/tcp': 8080},
+            ports={"8080/tcp": 8080},
             volumes={
                 Path(model_path): {
                     "bind": "/opt/ml/model",
@@ -39,7 +39,6 @@ def _start_torch_serve(
             environment={
                 "SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code",
                 "SAGEMAKER_PROGRAM": "inference.py",
-                "SAGEMAKER_SERVE_SECRET_KEY": secret_key,
                 "LOCAL_PYTHON": platform.python_version(),
                 **env_vars,
             },
@@ -103,7 +102,6 @@ def _upload_torchserve_artifacts(
         "SAGEMAKER_PROGRAM": "inference.py",
         "SAGEMAKER_REGION": sagemaker_session.boto_region_name,
         "SAGEMAKER_CONTAINER_LOG_LEVEL": "10",
-        "SAGEMAKER_SERVE_SECRET_KEY": secret_key,
         "LOCAL_PYTHON": platform.python_version(),
     }
     return s3_upload_path, env_vars
diff --git a/sagemaker-serve/src/sagemaker/serve/model_server/triton/model.py b/sagemaker-serve/src/sagemaker/serve/model_server/triton/model.py
index a1c731b0d6..7d49b0723d 100644
--- a/sagemaker-serve/src/sagemaker/serve/model_server/triton/model.py
+++ b/sagemaker-serve/src/sagemaker/serve/model_server/triton/model.py
@@ -26,10 +26,14 @@ def auto_complete_config(auto_complete_model_config):
     def initialize(self, args: dict) -> None:
         """Placeholder docstring"""
         serve_path = Path(TRITON_MODEL_DIR).joinpath("serve.pkl")
-        with open(str(serve_path), mode="rb") as f:
-            inference_spec, schema_builder = cloudpickle.load(f)
+        metadata_path = Path(TRITON_MODEL_DIR).joinpath("metadata.json")
 
-        # TODO: HMAC signing for integrity check
+        # Integrity check BEFORE deserialization to prevent RCE via malicious pickle
+        with open(str(serve_path), "rb") as f:
+            buffer = f.read()
+        perform_integrity_check(buffer=buffer, metadata_path=metadata_path)
+
+        inference_spec, schema_builder = cloudpickle.loads(buffer)
         self.inference_spec = inference_spec
         self.schema_builder = schema_builder
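The triton/model.py change above is the read side of the same scheme, and the ordering is the point: cloudpickle.loads can execute arbitrary code, so the check must run on the raw bytes first. Condensed, with the names from the hunk:

with open(str(serve_path), "rb") as f:
    buffer = f.read()
# Raises before any pickle bytecode runs if the digest does not match.
perform_integrity_check(buffer=buffer, metadata_path=metadata_path)
inference_spec, schema_builder = cloudpickle.loads(buffer)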
diff --git a/sagemaker-serve/src/sagemaker/serve/model_server/triton/server.py b/sagemaker-serve/src/sagemaker/serve/model_server/triton/server.py
index 134f12dd42..b425f8a689 100644
--- a/sagemaker-serve/src/sagemaker/serve/model_server/triton/server.py
+++ b/sagemaker-serve/src/sagemaker/serve/model_server/triton/server.py
@@ -41,7 +41,6 @@ def _start_triton_server(
         env_vars.update(
             {
                 "TRITON_MODEL_DIR": "/models/model",
-                "SAGEMAKER_SERVE_SECRET_KEY": secret_key,
                 "LOCAL_PYTHON": platform.python_version(),
             }
         )
@@ -133,7 +132,6 @@ def _upload_triton_artifacts(
     env_vars = {
         "SAGEMAKER_TRITON_DEFAULT_MODEL_NAME": "model",
         "TRITON_MODEL_DIR": "/opt/ml/model/model",
-        "SAGEMAKER_SERVE_SECRET_KEY": secret_key,
         "LOCAL_PYTHON": platform.python_version(),
     }
     return s3_upload_path, env_vars
diff --git a/sagemaker-serve/src/sagemaker/serve/validations/check_integrity.py b/sagemaker-serve/src/sagemaker/serve/validations/check_integrity.py
index 4363d8d6ed..880ca5b602 100644
--- a/sagemaker-serve/src/sagemaker/serve/validations/check_integrity.py
+++ b/sagemaker-serve/src/sagemaker/serve/validations/check_integrity.py
@@ -1,29 +1,21 @@
-"""Validates the integrity of pickled file with HMAC signing."""
+"""Validates the integrity of pickled file with SHA-256 hash."""
 from __future__ import absolute_import
 
-import secrets
 import hmac
 import hashlib
-import os
 from pathlib import Path
 
 from sagemaker.core.remote_function.core.serialization import _MetaData
 
 
-def generate_secret_key(nbytes: int = 32) -> str:
-    """Generates secret key"""
-    return secrets.token_hex(nbytes)
-
-
-def compute_hash(buffer: bytes, secret_key: str) -> str:
-    """Compute hash value using HMAC"""
-    return hmac.new(secret_key.encode(), msg=buffer, digestmod=hashlib.sha256).hexdigest()
+def compute_hash(buffer: bytes) -> str:
+    """Compute SHA-256 hash of the given buffer."""
+    return hashlib.sha256(buffer).hexdigest()
 
 
 def perform_integrity_check(buffer: bytes, metadata_path: Path):
-    """Validates the integrity of bytes by comparing the hash value"""
-    secret_key = os.environ.get("SAGEMAKER_SERVE_SECRET_KEY")
-    actual_hash_value = compute_hash(buffer=buffer, secret_key=secret_key)
+    """Validates the integrity of bytes by comparing the hash value."""
+    actual_hash_value = compute_hash(buffer=buffer)
 
     if not Path.exists(metadata_path):
         raise ValueError("Path to metadata.json does not exist")
diff --git a/sagemaker-serve/tests/unit/model_server/test_djl_utils.py b/sagemaker-serve/tests/unit/model_server/test_djl_utils.py
index 814feb3cc4..a66f04cbf4 100644
--- a/sagemaker-serve/tests/unit/model_server/test_djl_utils.py
+++ b/sagemaker-serve/tests/unit/model_server/test_djl_utils.py
@@ -6,7 +6,7 @@
     _get_default_batch_size,
     _tokens_from_chars,
     _tokens_from_words,
-    _set_tokens_to_tokens_threshold
+    _set_tokens_to_tokens_threshold,
 )
diff --git a/sagemaker-serve/tests/unit/model_server/test_in_process_model_server_app.py b/sagemaker-serve/tests/unit/model_server/test_in_process_model_server_app.py
index deeeefc704..d9c32c7e2d 100644
--- a/sagemaker-serve/tests/unit/model_server/test_in_process_model_server_app.py
+++ b/sagemaker-serve/tests/unit/model_server/test_in_process_model_server_app.py
@@ -14,10 +14,10 @@
 # Mock optional dependencies before importing
 mock_transformers = MagicMock()
-mock_pipeline_class = type('Pipeline', (), {})
+mock_pipeline_class = type("Pipeline", (), {})
 mock_transformers.Pipeline = mock_pipeline_class
-sys.modules['transformers'] = mock_transformers
-sys.modules['sentence_transformers'] = MagicMock()
+sys.modules["transformers"] = mock_transformers
+sys.modules["sentence_transformers"] = MagicMock()
 
 from sagemaker.serve.model_server.in_process_model_server.app import InProcessServer
 
@@ -25,22 +25,22 @@
 class TestInProcessServerInitialization(unittest.TestCase):
     """Test InProcessServer initialization."""
 
-    @patch('sagemaker.serve.model_server.in_process_model_server.app.uvicorn')
-    @patch('sagemaker.serve.model_server.in_process_model_server.app.FastAPI')
+    @patch("sagemaker.serve.model_server.in_process_model_server.app.uvicorn")
+    @patch("sagemaker.serve.model_server.in_process_model_server.app.FastAPI")
     def test_init_with_inference_spec(self, mock_fastapi, mock_uvicorn):
         """Test initialization with inference_spec."""
         mock_inference_spec = Mock()
         mock_model = Mock()
         mock_inference_spec.load.return_value = mock_model
         mock_schema_builder = Mock()
-
+
         server = InProcessServer(
             model="test-model",
             inference_spec=mock_inference_spec,
             schema_builder=mock_schema_builder,
-            task="text-generation"
+            task="text-generation",
         )
-
+
         self.assertEqual(server.model, "test-model")
         self.assertEqual(server.inference_spec, mock_inference_spec)
         self.assertEqual(server.schema_builder, mock_schema_builder)
@@ -62,19 +62,19 @@ def test_init_fallback_to_sentence_transformer(self):
         # This test requires sentence-transformers to be installed
         self.skipTest("Requires sentence-transformers package")
 
-    @patch('sagemaker.serve.model_server.in_process_model_server.app.uvicorn')
-    @patch('sagemaker.serve.model_server.in_process_model_server.app.FastAPI')
+    @patch("sagemaker.serve.model_server.in_process_model_server.app.uvicorn")
+    @patch("sagemaker.serve.model_server.in_process_model_server.app.FastAPI")
     def test_init_without_model_or_inference_spec_raises_error(self, mock_fastapi, mock_uvicorn):
         """Test that initialization without model or inference_spec raises ValueError."""
         mock_schema_builder = Mock()
-
+
         with self.assertRaises(ValueError) as context:
             InProcessServer(schema_builder=mock_schema_builder)
-
+
         self.assertIn("Either inference_spec or model must be provided", str(context.exception))
 
-    @patch('sagemaker.serve.model_server.in_process_model_server.app.uvicorn')
-    @patch('sagemaker.serve.model_server.in_process_model_server.app.FastAPI')
+    @patch("sagemaker.serve.model_server.in_process_model_server.app.uvicorn")
+    @patch("sagemaker.serve.model_server.in_process_model_server.app.FastAPI")
     def test_create_server_configuration(self, mock_fastapi, mock_uvicorn):
         """Test that server is created with correct configuration."""
         mock_inference_spec = Mock()
@@ -88,24 +88,24 @@ def test_create_server_configuration(self, mock_fastapi, mock_uvicorn):
         mock_uvicorn.Config.return_value = mock_config
         mock_server = Mock()
         mock_uvicorn.Server.return_value = mock_server
-
+
         server = InProcessServer(
             model="test-model",
             inference_spec=mock_inference_spec,
-            schema_builder=mock_schema_builder
+            schema_builder=mock_schema_builder,
         )
-
+
         # Verify FastAPI app was created
         mock_fastapi.assert_called_once()
         mock_app.include_router.assert_called_once()
-
+
         # Verify uvicorn config
         mock_uvicorn.Config.assert_called_once()
         config_call_args = mock_uvicorn.Config.call_args
-        self.assertEqual(config_call_args[1]['host'], "127.0.0.1")
-        self.assertEqual(config_call_args[1]['port'], 9007)
-        self.assertEqual(config_call_args[1]['log_level'], "info")
-
+        self.assertEqual(config_call_args[1]["host"], "127.0.0.1")
+        self.assertEqual(config_call_args[1]["port"], 9007)
+        self.assertEqual(config_call_args[1]["log_level"], "info")
+
         # Verify server attributes
         self.assertEqual(server.host, "127.0.0.1")
         self.assertEqual(server.port, 9007)
@@ -115,38 +115,39 @@
 class TestInProcessServerInvokeEndpoint(unittest.TestCase):
     """Test InProcessServer /invoke endpoint."""
 
-    @patch('sagemaker.serve.model_server.in_process_model_server.app.uvicorn')
-    @patch('sagemaker.serve.model_server.in_process_model_server.app.FastAPI')
+    @patch("sagemaker.serve.model_server.in_process_model_server.app.uvicorn")
+    @patch("sagemaker.serve.model_server.in_process_model_server.app.FastAPI")
     def test_invoke_with_inference_spec(self, mock_fastapi, mock_uvicorn):
         """Test /invoke endpoint with inference_spec."""
         mock_inference_spec = Mock()
         mock_model = Mock()
         mock_inference_spec.load.return_value = mock_model
         mock_inference_spec.invoke.return_value = {"predictions": [0.1, 0.9]}
-
+
         mock_schema_builder = Mock()
         mock_deserializer = Mock()
         mock_deserializer.deserialize.return_value = {"inputs": [[1, 2, 3]]}
         mock_schema_builder.input_deserializer = mock_deserializer
-
+
         server = InProcessServer(
             model="test-model",
             inference_spec=mock_inference_spec,
-            schema_builder=mock_schema_builder
+            schema_builder=mock_schema_builder,
         )
-
+
         # Simulate request
         mock_request = AsyncMock()
         mock_request.headers = {"Content-Type": ["application/json"]}
         mock_request.body = AsyncMock(return_value=b'{"inputs": [[1, 2, 3]]}')
-
+
         # Get the invoke function from the router
         invoke_func = server._router.routes[0].endpoint
-
+
         # Run async function
         import asyncio
+
         result = asyncio.run(invoke_func(mock_request))
-
+
         self.assertEqual(result, {"predictions": [0.1, 0.9]})
         mock_inference_spec.invoke.assert_called_once_with({"inputs": [[1, 2, 3]]}, mock_model)
@@ -166,51 +167,51 @@
 class TestInProcessServerLifecycle(unittest.TestCase):
     """Test InProcessServer lifecycle methods."""
 
-    @patch('sagemaker.serve.model_server.in_process_model_server.app.uvicorn')
-    @patch('sagemaker.serve.model_server.in_process_model_server.app.FastAPI')
+    @patch("sagemaker.serve.model_server.in_process_model_server.app.uvicorn")
+    @patch("sagemaker.serve.model_server.in_process_model_server.app.FastAPI")
     def test_start_server(self, mock_fastapi, mock_uvicorn):
         """Test starting the server."""
         mock_inference_spec = Mock()
         mock_inference_spec.load.return_value = Mock()
         mock_schema_builder = Mock()
-
+
         server = InProcessServer(
             model="test-model",
             inference_spec=mock_inference_spec,
-            schema_builder=mock_schema_builder
+            schema_builder=mock_schema_builder,
         )
-
-        with patch.object(threading.Thread, 'start') as mock_thread_start:
+
+        with patch.object(threading.Thread, "start") as mock_thread_start:
             server.start_server()
             mock_thread_start.assert_called_once()
             self.assertIsNotNone(server._thread)
 
-    @patch('sagemaker.serve.model_server.in_process_model_server.app.uvicorn')
-    @patch('sagemaker.serve.model_server.in_process_model_server.app.FastAPI')
+    @patch("sagemaker.serve.model_server.in_process_model_server.app.uvicorn")
+    @patch("sagemaker.serve.model_server.in_process_model_server.app.FastAPI")
     def test_start_server_when_already_running(self, mock_fastapi, mock_uvicorn):
         """Test starting server when it's already running."""
         mock_inference_spec = Mock()
         mock_inference_spec.load.return_value = Mock()
         mock_schema_builder = Mock()
-
+
         server = InProcessServer(
             model="test-model",
             inference_spec=mock_inference_spec,
-            schema_builder=mock_schema_builder
+            schema_builder=mock_schema_builder,
         )
-
+
         # Mock thread as already running
         mock_thread = Mock()
         mock_thread.is_alive.return_value = True
         server._thread = mock_thread
-
-        with patch.object(threading.Thread, 'start') as mock_thread_start:
+
+        with patch.object(threading.Thread, "start") as mock_thread_start:
             server.start_server()
             # Should not start a new thread
             mock_thread_start.assert_not_called()
 
-    @patch('sagemaker.serve.model_server.in_process_model_server.app.uvicorn')
-    @patch('sagemaker.serve.model_server.in_process_model_server.app.FastAPI')
+    @patch("sagemaker.serve.model_server.in_process_model_server.app.uvicorn")
+    @patch("sagemaker.serve.model_server.in_process_model_server.app.FastAPI")
     def test_stop_server(self, mock_fastapi, mock_uvicorn):
         """Test stopping the server."""
         mock_inference_spec = Mock()
@@ -218,26 +219,26 @@ def test_stop_server(self, mock_fastapi, mock_uvicorn):
         mock_schema_builder = Mock()
         mock_server = Mock()
         mock_uvicorn.Server.return_value = mock_server
-
+
         server = InProcessServer(
             model="test-model",
             inference_spec=mock_inference_spec,
-            schema_builder=mock_schema_builder
+            schema_builder=mock_schema_builder,
         )
-
+
         # Mock thread as running
         mock_thread = Mock()
         mock_thread.is_alive.return_value = True
         server._thread = mock_thread
-
+
         server.stop_server()
-
+
         self.assertTrue(server._shutdown_event.is_set())
         mock_server.handle_exit.assert_called_once_with(sig=0, frame=None)
         mock_thread.join.assert_called_once()
 
-    @patch('sagemaker.serve.model_server.in_process_model_server.app.uvicorn')
-    @patch('sagemaker.serve.model_server.in_process_model_server.app.FastAPI')
+    @patch("sagemaker.serve.model_server.in_process_model_server.app.uvicorn")
+    @patch("sagemaker.serve.model_server.in_process_model_server.app.FastAPI")
     def test_stop_server_when_not_running(self, mock_fastapi, mock_uvicorn):
         """Test stopping server when it's not running."""
         mock_inference_spec = Mock()
@@ -245,40 +246,40 @@ def test_stop_server_when_not_running(self, mock_fastapi, mock_uvicorn):
         mock_schema_builder = Mock()
         mock_server = Mock()
         mock_uvicorn.Server.return_value = mock_server
-
+
         server = InProcessServer(
             model="test-model",
             inference_spec=mock_inference_spec,
-            schema_builder=mock_schema_builder
+            schema_builder=mock_schema_builder,
         )
-
+
         # No thread or thread not alive
         server._thread = None
-
+
         # Should not raise error
         server.stop_server()
         mock_server.handle_exit.assert_not_called()
 
-    @patch('sagemaker.serve.model_server.in_process_model_server.app.uvicorn')
-    @patch('sagemaker.serve.model_server.in_process_model_server.app.FastAPI')
-    @patch('sagemaker.serve.model_server.in_process_model_server.app.asyncio')
+    @patch("sagemaker.serve.model_server.in_process_model_server.app.uvicorn")
+    @patch("sagemaker.serve.model_server.in_process_model_server.app.FastAPI")
+    @patch("sagemaker.serve.model_server.in_process_model_server.app.asyncio")
     def test_start_run_async_in_thread(self, mock_asyncio, mock_fastapi, mock_uvicorn):
         """Test _start_run_async_in_thread method."""
         mock_inference_spec = Mock()
         mock_inference_spec.load.return_value = Mock()
         mock_schema_builder = Mock()
-
+
         server = InProcessServer(
             model="test-model",
             inference_spec=mock_inference_spec,
-            schema_builder=mock_schema_builder
+            schema_builder=mock_schema_builder,
        )
-
+
         mock_loop = Mock()
         mock_asyncio.new_event_loop.return_value = mock_loop
-
+
         server._start_run_async_in_thread()
-
+
         mock_asyncio.new_event_loop.assert_called_once()
         mock_asyncio.set_event_loop.assert_called_once_with(mock_loop)
         mock_loop.run_until_complete.assert_called_once()
@@ -300,5 +301,5 @@ def test_invoke_without_inputs_key(self):
         self.skipTest("Requires transformers package")
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     unittest.main()
diff --git a/sagemaker-serve/tests/unit/model_server/test_multi_model_server_inference.py b/sagemaker-serve/tests/unit/model_server/test_multi_model_server_inference.py
index 5842bf0f8d..34c0fa671f 100644
--- a/sagemaker-serve/tests/unit/model_server/test_multi_model_server_inference.py
+++ b/sagemaker-serve/tests/unit/model_server/test_multi_model_server_inference.py
@@ -13,128 +13,161 @@
 class TestMultiModelServerInference(unittest.TestCase):
     def test_predict_fn_logic(self):
         """Test predict_fn logic."""
+
         def predict_fn(input_data, predict_callable, context=None):
             return predict_callable(input_data)
-
+
         mock_predict_callable = Mock(return_value=[0.1, 0.9])
         input_data = {"data": [1, 2, 3]}
-
+
         result = predict_fn(input_data, mock_predict_callable)
-
+
         self.assertEqual(result, [0.1, 0.9])
         mock_predict_callable.assert_called_once_with(input_data)
 
     def test_input_fn_with_preprocess_logic(self):
         """Test input_fn with preprocess logic."""
+
         def input_fn(input_data, content_type, schema_builder, inference_spec, context=None):
             # Deserialize
             if hasattr(schema_builder, "custom_input_translator"):
                 deserialized_data = schema_builder.custom_input_translator.deserialize(
-                    io.BytesIO(input_data.encode("utf-8")) if isinstance(input_data, str) else io.BytesIO(input_data),
+                    (
+                        io.BytesIO(input_data.encode("utf-8"))
+                        if isinstance(input_data, str)
+                        else io.BytesIO(input_data)
+                    ),
                     content_type,
                 )
             else:
                 deserialized_data = schema_builder.input_deserializer.deserialize(
-                    io.BytesIO(input_data.encode("utf-8")) if isinstance(input_data, str) else io.BytesIO(input_data),
+                    (
+                        io.BytesIO(input_data.encode("utf-8"))
+                        if isinstance(input_data, str)
+                        else io.BytesIO(input_data)
+                    ),
                     content_type[0],
                 )
-
+
             # Preprocess if available
             if hasattr(inference_spec, "preprocess"):
                 preprocessed = inference_spec.preprocess(deserialized_data)
                 if preprocessed is not None:
                     return preprocessed
-
+
             return deserialized_data
-
+
         schema_builder = Mock()
         schema_builder.custom_input_translator = Mock()
         schema_builder.custom_input_translator.deserialize = Mock(return_value={"data": [1, 2, 3]})
-
+
         inference_spec = Mock()
         inference_spec.preprocess = Mock(return_value={"preprocessed": True})
-
-        result = input_fn('{"data": [1, 2, 3]}', ["application/json"], schema_builder, inference_spec)
-
+
+        result = input_fn(
+            '{"data": [1, 2, 3]}', ["application/json"], schema_builder, inference_spec
+        )
+
         self.assertEqual(result, {"preprocessed": True})
         inference_spec.preprocess.assert_called_once_with({"data": [1, 2, 3]})
 
     def test_input_fn_with_bytes_input_logic(self):
         """Test input_fn with bytes input."""
+
         def input_fn(input_data, content_type, schema_builder, inference_spec, context=None):
             if hasattr(schema_builder, "custom_input_translator"):
                 deserialized_data = schema_builder.custom_input_translator.deserialize(
-                    io.BytesIO(input_data) if isinstance(input_data, (bytes, bytearray)) else io.BytesIO(input_data.encode("utf-8")),
+                    (
+                        io.BytesIO(input_data)
+                        if isinstance(input_data, (bytes, bytearray))
+                        else io.BytesIO(input_data.encode("utf-8"))
+                    ),
                     content_type,
                 )
             else:
                 deserialized_data = schema_builder.input_deserializer.deserialize(
-                    io.BytesIO(input_data) if isinstance(input_data, (bytes, bytearray)) else io.BytesIO(input_data.encode("utf-8")),
+                    (
+                        io.BytesIO(input_data)
+                        if isinstance(input_data, (bytes, bytearray))
+                        else io.BytesIO(input_data.encode("utf-8"))
+                    ),
                     content_type[0],
                 )
             return deserialized_data
-
+
         schema_builder = Mock()
         schema_builder.custom_input_translator = Mock()
         schema_builder.custom_input_translator.deserialize = Mock(return_value={"data": [1, 2, 3]})
-
+
         inference_spec = None
-
-        result = input_fn(b'{"data": [1, 2, 3]}', ["application/json"], schema_builder, inference_spec)
-
+
+        result = input_fn(
+            b'{"data": [1, 2, 3]}', ["application/json"], schema_builder, inference_spec
+        )
+
         self.assertEqual(result, {"data": [1, 2, 3]})
 
     def test_output_fn_with_postprocess_logic(self):
         """Test output_fn with postprocess logic."""
+
         def output_fn(predictions, accept_type, schema_builder, inference_spec, context=None):
             # Postprocess if available
             if hasattr(inference_spec, "postprocess"):
                 postprocessed = inference_spec.postprocess(predictions)
                 if postprocessed is not None:
                     predictions = postprocessed
-
+
             # Serialize
             if hasattr(schema_builder, "custom_output_translator"):
                 return schema_builder.custom_output_translator.serialize(predictions, accept_type)
             else:
                 return schema_builder.output_serializer.serialize(predictions)
-
+
         schema_builder = Mock()
         schema_builder.custom_output_translator = Mock()
-        schema_builder.custom_output_translator.serialize = Mock(return_value=b'{"predictions": [0.1, 0.9]}')
-
+        schema_builder.custom_output_translator.serialize = Mock(
+            return_value=b'{"predictions": [0.1, 0.9]}'
+        )
+
         inference_spec = Mock()
         inference_spec.postprocess = Mock(return_value={"postprocessed": True})
-
+
         result = output_fn([0.1, 0.9], "application/json", schema_builder, inference_spec)
-
+
         inference_spec.postprocess.assert_called_once_with([0.1, 0.9])
-        schema_builder.custom_output_translator.serialize.assert_called_once_with({"postprocessed": True}, "application/json")
+        schema_builder.custom_output_translator.serialize.assert_called_once_with(
+            {"postprocessed": True}, "application/json"
+        )
 
     def test_output_fn_postprocess_returns_none_logic(self):
         """Test output_fn when postprocess returns None."""
+
         def output_fn(predictions, accept_type, schema_builder, inference_spec, context=None):
             if hasattr(inference_spec, "postprocess"):
                 postprocessed = inference_spec.postprocess(predictions)
                 if postprocessed is not None:
                     predictions = postprocessed
-
+
             if hasattr(schema_builder, "custom_output_translator"):
                 return schema_builder.custom_output_translator.serialize(predictions, accept_type)
             else:
                 return schema_builder.output_serializer.serialize(predictions)
-
+
         schema_builder = Mock()
         schema_builder.custom_output_translator = Mock()
-        schema_builder.custom_output_translator.serialize = Mock(return_value=b'{"predictions": [0.1, 0.9]}')
-
+        schema_builder.custom_output_translator.serialize = Mock(
+            return_value=b'{"predictions": [0.1, 0.9]}'
+        )
+
         inference_spec = Mock()
         inference_spec.postprocess = Mock(return_value=None)
-
+
         result = output_fn([0.1, 0.9], "application/json", schema_builder, inference_spec)
-
+
         # Should use original predictions since postprocess returned None
-        schema_builder.custom_output_translator.serialize.assert_called_once_with([0.1, 0.9], "application/json")
+        schema_builder.custom_output_translator.serialize.assert_called_once_with(
+            [0.1, 0.9], "application/json"
+        )
 
 if __name__ == "__main__":
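Taken together, these tests pin down the local handler contract: input_fn deserializes and optionally preprocesses, predict_fn invokes the model, and output_fn optionally postprocesses and serializes. End to end, the chain they exercise looks roughly like this (names as in the test helpers; model is a stand-in):

raw = b'{"data": [1, 2, 3]}'
payload = input_fn(raw, ["application/json"], schema_builder, inference_spec)
predictions = predict_fn(payload, model.predict)
body = output_fn(predictions, "application/json", schema_builder, inference_spec)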
@patch("sagemaker.serve.model_server.multi_model_server.prepare._check_docker_disk_usage") + @patch("sagemaker.serve.model_server.multi_model_server.prepare._check_disk_space") def test_create_dir_structure_raises_on_file(self, mock_disk_space, mock_docker_disk): """Test _create_dir_structure raises ValueError when path is a file.""" from sagemaker.serve.model_server.multi_model_server.prepare import _create_dir_structure - + file_path = Path(self.temp_dir) / "file.txt" file_path.touch() - + with self.assertRaises(ValueError) as context: _create_dir_structure(str(file_path)) self.assertIn("not a valid directory", str(context.exception)) - @patch('sagemaker.serve.model_server.multi_model_server.prepare._copy_jumpstart_artifacts') - @patch('sagemaker.serve.model_server.multi_model_server.prepare._create_dir_structure') + @patch("sagemaker.serve.model_server.multi_model_server.prepare._copy_jumpstart_artifacts") + @patch("sagemaker.serve.model_server.multi_model_server.prepare._create_dir_structure") def test_prepare_mms_js_resources(self, mock_create_dir, mock_copy_js): """Test prepare_mms_js_resources calls necessary functions.""" from sagemaker.serve.model_server.multi_model_server.prepare import prepare_mms_js_resources - + mock_model_path = Path(self.temp_dir) / "model" mock_code_dir = mock_model_path / "code" mock_create_dir.return_value = (mock_model_path, mock_code_dir) mock_copy_js.return_value = ({"config": "data"}, True) - + result = prepare_mms_js_resources( model_path=str(mock_model_path), js_id="test-js-id", - model_data="s3://bucket/model.tar.gz" + model_data="s3://bucket/model.tar.gz", ) - + mock_create_dir.assert_called_once_with(str(mock_model_path)) - mock_copy_js.assert_called_once_with("s3://bucket/model.tar.gz", "test-js-id", mock_code_dir) + mock_copy_js.assert_called_once_with( + "s3://bucket/model.tar.gz", "test-js-id", mock_code_dir + ) self.assertEqual(result, ({"config": "data"}, True)) - @patch('builtins.input', return_value='') - @patch('sagemaker.serve.model_server.multi_model_server.prepare.compute_hash') - @patch('sagemaker.serve.model_server.multi_model_server.prepare.generate_secret_key') - @patch('sagemaker.serve.model_server.multi_model_server.prepare.capture_dependencies') - @patch('shutil.copy2') - def test_prepare_for_mms_creates_structure(self, mock_copy, mock_capture, mock_gen_key, mock_hash, mock_input): + @patch("builtins.input", return_value="") + @patch("sagemaker.serve.model_server.multi_model_server.prepare.compute_hash") + @patch("sagemaker.serve.model_server.multi_model_server.prepare.capture_dependencies") + @patch("shutil.copy2") + def test_prepare_for_mms_creates_structure( + self, mock_copy, mock_capture, mock_hash, mock_input + ): """Test prepare_for_mms creates directory structure and files.""" from sagemaker.serve.model_server.multi_model_server.prepare import prepare_for_mms - + model_path = Path(self.temp_dir) / "model" code_dir = model_path / "code" code_dir.mkdir(parents=True) - + # Create serve.pkl file serve_pkl = code_dir / "serve.pkl" serve_pkl.write_bytes(b"test data") - - mock_gen_key.return_value = "test-secret-key" + mock_hash.return_value = "test-hash" mock_session = Mock() mock_inference_spec = Mock() - - with patch('builtins.open', mock_open(read_data=b"test data")): + + with patch("builtins.open", mock_open(read_data=b"test data")): secret_key = prepare_for_mms( model_path=str(model_path), shared_libs=[], dependencies={}, session=mock_session, image_uri="test-image", - inference_spec=mock_inference_spec + 
inference_spec=mock_inference_spec, ) - - self.assertEqual(secret_key, "test-secret-key") + mock_inference_spec.prepare.assert_called_once_with(str(model_path)) mock_capture.assert_called_once() - @patch('builtins.input', return_value='') - @patch('sagemaker.serve.model_server.multi_model_server.prepare.compute_hash') - @patch('sagemaker.serve.model_server.multi_model_server.prepare.generate_secret_key') - @patch('sagemaker.serve.model_server.multi_model_server.prepare.capture_dependencies') - @patch('shutil.copy2') - def test_prepare_for_mms_raises_on_invalid_dir(self, mock_copy, mock_capture, mock_gen_key, mock_hash, mock_input): + @patch("builtins.input", return_value="") + @patch("sagemaker.serve.model_server.multi_model_server.prepare.compute_hash") + @patch("sagemaker.serve.model_server.multi_model_server.prepare.capture_dependencies") + @patch("shutil.copy2") + def test_prepare_for_mms_raises_on_invalid_dir( + self, mock_copy, mock_capture, mock_hash, mock_input + ): """Test prepare_for_mms raises exception for invalid directory.""" from sagemaker.serve.model_server.multi_model_server.prepare import prepare_for_mms - + file_path = Path(self.temp_dir) / "file.txt" file_path.touch() - + mock_session = Mock() - + with self.assertRaises(Exception) as context: prepare_for_mms( model_path=str(file_path), shared_libs=[], dependencies={}, session=mock_session, - image_uri="test-image" + image_uri="test-image", ) self.assertIn("not a valid directory", str(context.exception)) - @patch('builtins.input', return_value='') - @patch('sagemaker.serve.model_server.multi_model_server.prepare.compute_hash') - @patch('sagemaker.serve.model_server.multi_model_server.prepare.generate_secret_key') - @patch('sagemaker.serve.model_server.multi_model_server.prepare.capture_dependencies') - @patch('shutil.copy2') - def test_prepare_for_mms_copies_shared_libs(self, mock_copy, mock_capture, mock_gen_key, mock_hash, mock_input): + @patch("builtins.input", return_value="") + @patch("sagemaker.serve.model_server.multi_model_server.prepare.compute_hash") + @patch("sagemaker.serve.model_server.multi_model_server.prepare.capture_dependencies") + @patch("shutil.copy2") + def test_prepare_for_mms_copies_shared_libs( + self, mock_copy, mock_capture, mock_hash, mock_input + ): """Test prepare_for_mms copies shared libraries.""" from sagemaker.serve.model_server.multi_model_server.prepare import prepare_for_mms - + model_path = Path(self.temp_dir) / "model" code_dir = model_path / "code" code_dir.mkdir(parents=True) - + serve_pkl = code_dir / "serve.pkl" serve_pkl.write_bytes(b"test data") - + shared_lib = Path(self.temp_dir) / "lib.so" shared_lib.touch() - - mock_gen_key.return_value = "test-key" + mock_hash.return_value = "test-hash" mock_session = Mock() - - with patch('builtins.open', mock_open(read_data=b"test data")): + + with patch("builtins.open", mock_open(read_data=b"test data")): prepare_for_mms( model_path=str(model_path), shared_libs=[str(shared_lib)], dependencies={}, session=mock_session, - image_uri="test-image" + image_uri="test-image", ) - + # Verify copy2 was called for shared lib self.assertTrue(any(str(shared_lib) in str(call) for call in mock_copy.call_args_list)) diff --git a/sagemaker-serve/tests/unit/model_server/test_multi_model_server_server.py b/sagemaker-serve/tests/unit/model_server/test_multi_model_server_server.py index 02ae4dc596..a19c808264 100644 --- a/sagemaker-serve/tests/unit/model_server/test_multi_model_server_server.py +++ 
b/sagemaker-serve/tests/unit/model_server/test_multi_model_server_server.py @@ -8,97 +8,94 @@ class TestLocalMultiModelServer(unittest.TestCase): """Test LocalMultiModelServer class.""" - @patch('sagemaker.serve.model_server.multi_model_server.server.Path') + @patch("sagemaker.serve.model_server.multi_model_server.server.Path") def test_start_serving_creates_container(self, mock_path): """Test _start_serving creates and configures container.""" from sagemaker.serve.model_server.multi_model_server.server import LocalMultiModelServer - + server = LocalMultiModelServer() mock_client = Mock() mock_container = Mock() mock_client.containers.run.return_value = mock_container - + mock_path_obj = Mock() mock_path.return_value.joinpath.return_value = mock_path_obj - + server._start_serving( client=mock_client, image="test-image:latest", model_path="/path/to/model", secret_key="test-secret", - env_vars={"CUSTOM_VAR": "value"} + env_vars={"CUSTOM_VAR": "value"}, ) - + self.assertEqual(server.container, mock_container) mock_client.containers.run.assert_called_once() call_kwargs = mock_client.containers.run.call_args[1] - self.assertIn("SAGEMAKER_SERVE_SECRET_KEY", call_kwargs["environment"]) - self.assertEqual(call_kwargs["environment"]["SAGEMAKER_SERVE_SECRET_KEY"], "test-secret") + self.assertIn("SAGEMAKER_SUBMIT_DIRECTORY", call_kwargs["environment"]) - @patch('sagemaker.serve.model_server.multi_model_server.server.Path') + @patch("sagemaker.serve.model_server.multi_model_server.server.Path") def test_start_serving_with_no_env_vars(self, mock_path): """Test _start_serving with no custom env vars.""" from sagemaker.serve.model_server.multi_model_server.server import LocalMultiModelServer - + server = LocalMultiModelServer() mock_client = Mock() mock_container = Mock() mock_client.containers.run.return_value = mock_container - + mock_path_obj = Mock() mock_path.return_value.joinpath.return_value = mock_path_obj - + server._start_serving( client=mock_client, image="test-image:latest", model_path="/path/to/model", secret_key="test-secret", - env_vars=None + env_vars=None, ) - + call_kwargs = mock_client.containers.run.call_args[1] self.assertIn("SAGEMAKER_SUBMIT_DIRECTORY", call_kwargs["environment"]) self.assertIn("SAGEMAKER_PROGRAM", call_kwargs["environment"]) - @patch('sagemaker.serve.model_server.multi_model_server.server.requests.post') - @patch('sagemaker.serve.model_server.multi_model_server.server.get_docker_host') + @patch("sagemaker.serve.model_server.multi_model_server.server.requests.post") + @patch("sagemaker.serve.model_server.multi_model_server.server.get_docker_host") def test_invoke_multi_model_server_serving_success(self, mock_get_host, mock_post): """Test _invoke_multi_model_server_serving successful request.""" from sagemaker.serve.model_server.multi_model_server.server import LocalMultiModelServer - + server = LocalMultiModelServer() mock_get_host.return_value = "localhost" mock_response = Mock() mock_response.content = b'{"result": "success"}' mock_post.return_value = mock_response - + result = server._invoke_multi_model_server_serving( - request='{"input": "data"}', - content_type="application/json", - accept="application/json" + request='{"input": "data"}', content_type="application/json", accept="application/json" ) - + self.assertEqual(result, b'{"result": "success"}') mock_post.assert_called_once() call_kwargs = mock_post.call_args[1] self.assertEqual(call_kwargs["headers"]["Content-Type"], "application/json") self.assertEqual(call_kwargs["headers"]["Accept"], 
"application/json") - @patch('sagemaker.serve.model_server.multi_model_server.server.requests.post') - @patch('sagemaker.serve.model_server.multi_model_server.server.get_docker_host') + @patch("sagemaker.serve.model_server.multi_model_server.server.requests.post") + @patch("sagemaker.serve.model_server.multi_model_server.server.get_docker_host") def test_invoke_multi_model_server_serving_failure(self, mock_get_host, mock_post): """Test _invoke_multi_model_server_serving handles errors.""" from sagemaker.serve.model_server.multi_model_server.server import LocalMultiModelServer - + server = LocalMultiModelServer() mock_get_host.return_value = "localhost" mock_post.side_effect = Exception("Connection error") - + with self.assertRaises(Exception) as context: server._invoke_multi_model_server_serving( request='{"input": "data"}', content_type="application/json", - accept="application/json" + accept="application/json", ) self.assertIn("Unable to send request", str(context.exception)) @@ -106,88 +103,97 @@ def test_invoke_multi_model_server_serving_failure(self, mock_get_host, mock_pos class TestSageMakerMultiModelServer(unittest.TestCase): """Test SageMakerMultiModelServer class.""" - @patch('sagemaker.serve.model_server.multi_model_server.server.S3Uploader') - @patch('sagemaker.serve.model_server.multi_model_server.server.determine_bucket_and_prefix') - @patch('sagemaker.serve.model_server.multi_model_server.server.fw_utils') - @patch('sagemaker.serve.model_server.multi_model_server.server._is_s3_uri') - def test_upload_server_artifacts_with_s3_path(self, mock_is_s3, mock_fw_utils, mock_determine, mock_uploader): + @patch("sagemaker.serve.model_server.multi_model_server.server.S3Uploader") + @patch("sagemaker.serve.model_server.multi_model_server.server.determine_bucket_and_prefix") + @patch("sagemaker.serve.model_server.multi_model_server.server.fw_utils") + @patch("sagemaker.serve.model_server.multi_model_server.server._is_s3_uri") + def test_upload_server_artifacts_with_s3_path( + self, mock_is_s3, mock_fw_utils, mock_determine, mock_uploader + ): """Test _upload_server_artifacts with S3 path.""" from sagemaker.serve.model_server.multi_model_server.server import SageMakerMultiModelServer - + server = SageMakerMultiModelServer() mock_is_s3.return_value = True mock_session = Mock() mock_session.boto_region_name = "us-west-2" - + model_data, env_vars = server._upload_server_artifacts( model_path="s3://bucket/model", secret_key="test-key", sagemaker_session=mock_session, - should_upload_artifacts=False + should_upload_artifacts=False, ) - + self.assertIsNotNone(model_data) self.assertEqual(model_data["S3DataSource"]["S3Uri"], "s3://bucket/model/") - @patch('sagemaker.serve.model_server.multi_model_server.server.S3Uploader') - @patch('sagemaker.serve.model_server.multi_model_server.server.s3_path_join') - @patch('sagemaker.serve.model_server.multi_model_server.server.determine_bucket_and_prefix') - @patch('sagemaker.serve.model_server.multi_model_server.server.parse_s3_url') - @patch('sagemaker.serve.model_server.multi_model_server.server.fw_utils') - @patch('sagemaker.serve.model_server.multi_model_server.server._is_s3_uri') - @patch('sagemaker.serve.model_server.multi_model_server.server.Path') - def test_upload_server_artifacts_uploads_to_s3(self, mock_path, mock_is_s3, mock_fw_utils, - mock_parse, mock_determine, mock_s3_join, mock_uploader): + @patch("sagemaker.serve.model_server.multi_model_server.server.S3Uploader") + 
@patch("sagemaker.serve.model_server.multi_model_server.server.s3_path_join") + @patch("sagemaker.serve.model_server.multi_model_server.server.determine_bucket_and_prefix") + @patch("sagemaker.serve.model_server.multi_model_server.server.parse_s3_url") + @patch("sagemaker.serve.model_server.multi_model_server.server.fw_utils") + @patch("sagemaker.serve.model_server.multi_model_server.server._is_s3_uri") + @patch("sagemaker.serve.model_server.multi_model_server.server.Path") + def test_upload_server_artifacts_uploads_to_s3( + self, + mock_path, + mock_is_s3, + mock_fw_utils, + mock_parse, + mock_determine, + mock_s3_join, + mock_uploader, + ): """Test _upload_server_artifacts uploads artifacts to S3.""" from sagemaker.serve.model_server.multi_model_server.server import SageMakerMultiModelServer - + server = SageMakerMultiModelServer() mock_is_s3.return_value = False mock_parse.return_value = ("bucket", "prefix") mock_determine.return_value = ("bucket", "code_prefix") mock_s3_join.return_value = "s3://bucket/code_prefix/code" mock_uploader.upload.return_value = "s3://bucket/code_prefix/code" - + mock_path_obj = Mock() mock_code_dir = Mock() mock_path_obj.joinpath.return_value = mock_code_dir mock_path.return_value = mock_path_obj - + mock_session = Mock() mock_session.boto_region_name = "us-west-2" - + model_data, env_vars = server._upload_server_artifacts( model_path="/local/model", secret_key="test-key", sagemaker_session=mock_session, s3_model_data_url="s3://bucket/prefix", image="test-image", - should_upload_artifacts=True + should_upload_artifacts=True, ) - + self.assertIsNotNone(model_data) - self.assertIn("SAGEMAKER_SERVE_SECRET_KEY", env_vars) - self.assertEqual(env_vars["SAGEMAKER_SERVE_SECRET_KEY"], "test-key") + self.assertIn("SAGEMAKER_SUBMIT_DIRECTORY", env_vars) - @patch('sagemaker.serve.model_server.multi_model_server.server._is_s3_uri') + @patch("sagemaker.serve.model_server.multi_model_server.server._is_s3_uri") def test_upload_server_artifacts_no_upload(self, mock_is_s3): """Test _upload_server_artifacts without uploading.""" from sagemaker.serve.model_server.multi_model_server.server import SageMakerMultiModelServer - + server = SageMakerMultiModelServer() mock_is_s3.return_value = False mock_session = Mock() mock_session.boto_region_name = "us-west-2" - + model_data, env_vars = server._upload_server_artifacts( model_path="/local/model", secret_key="test-key", sagemaker_session=mock_session, - should_upload_artifacts=False + should_upload_artifacts=False, ) - + self.assertIsNone(model_data) - self.assertIn("SAGEMAKER_SERVE_SECRET_KEY", env_vars) + self.assertIn("SAGEMAKER_SUBMIT_DIRECTORY", env_vars) class TestUpdateEnvVars(unittest.TestCase): @@ -196,17 +202,17 @@ class TestUpdateEnvVars(unittest.TestCase): def test_update_env_vars_with_none(self): """Test _update_env_vars with None input.""" from sagemaker.serve.model_server.multi_model_server.server import _update_env_vars - + result = _update_env_vars(None) self.assertIsInstance(result, dict) def test_update_env_vars_with_custom_vars(self): """Test _update_env_vars with custom variables.""" from sagemaker.serve.model_server.multi_model_server.server import _update_env_vars - + custom_vars = {"CUSTOM_KEY": "custom_value"} result = _update_env_vars(custom_vars) - + self.assertIn("CUSTOM_KEY", result) self.assertEqual(result["CUSTOM_KEY"], "custom_value") diff --git a/sagemaker-serve/tests/unit/model_server/test_smd_prepare.py b/sagemaker-serve/tests/unit/model_server/test_smd_prepare.py index 4d5a0a7de8..aa21763180 
100644 --- a/sagemaker-serve/tests/unit/model_server/test_smd_prepare.py +++ b/sagemaker-serve/tests/unit/model_server/test_smd_prepare.py @@ -17,114 +17,102 @@ def tearDown(self): if Path(self.temp_dir).exists(): shutil.rmtree(self.temp_dir) - @patch('sagemaker.serve.model_server.smd.prepare.compute_hash') - @patch('sagemaker.serve.model_server.smd.prepare.generate_secret_key') - @patch('sagemaker.serve.model_server.smd.prepare.capture_dependencies') - @patch('shutil.copy2') - def test_prepare_for_smd_with_inference_spec(self, mock_copy, mock_capture, mock_gen_key, mock_hash): + @patch("sagemaker.serve.model_server.smd.prepare.compute_hash") + @patch("sagemaker.serve.model_server.smd.prepare.capture_dependencies") + @patch("shutil.copy2") + def test_prepare_for_smd_with_inference_spec(self, mock_copy, mock_capture, mock_hash): """Test prepare_for_smd with InferenceSpec.""" from sagemaker.serve.model_server.smd.prepare import prepare_for_smd from sagemaker.serve.spec.inference_spec import InferenceSpec - + model_path = Path(self.temp_dir) / "model" code_dir = model_path / "code" code_dir.mkdir(parents=True) - + serve_pkl = code_dir / "serve.pkl" serve_pkl.write_bytes(b"test data") - - mock_gen_key.return_value = "test-secret-key" + mock_hash.return_value = "test-hash" mock_inference_spec = Mock(spec=InferenceSpec) - - with patch('builtins.open', mock_open(read_data=b"test data")): + + with patch("builtins.open", mock_open(read_data=b"test data")): secret_key = prepare_for_smd( model_path=str(model_path), shared_libs=[], dependencies={}, - inference_spec=mock_inference_spec + inference_spec=mock_inference_spec, ) - - self.assertEqual(secret_key, "test-secret-key") + mock_inference_spec.prepare.assert_called_once_with(str(model_path)) - @patch('os.rename') - @patch('sagemaker.serve.model_server.smd.prepare.compute_hash') - @patch('sagemaker.serve.model_server.smd.prepare.generate_secret_key') - @patch('sagemaker.serve.model_server.smd.prepare.capture_dependencies') - @patch('shutil.copy2') - def test_prepare_for_smd_with_custom_orchestrator(self, mock_copy, mock_capture, mock_gen_key, mock_hash, mock_rename): + @patch("os.rename") + @patch("sagemaker.serve.model_server.smd.prepare.compute_hash") + @patch("sagemaker.serve.model_server.smd.prepare.capture_dependencies") + @patch("shutil.copy2") + def test_prepare_for_smd_with_custom_orchestrator( + self, mock_copy, mock_capture, mock_hash, mock_rename + ): """Test prepare_for_smd with CustomOrchestrator.""" from sagemaker.serve.model_server.smd.prepare import prepare_for_smd from sagemaker.serve.spec.inference_base import CustomOrchestrator - + model_path = Path(self.temp_dir) / "model" code_dir = model_path / "code" code_dir.mkdir(parents=True) - + serve_pkl = code_dir / "serve.pkl" serve_pkl.write_bytes(b"test data") - - mock_gen_key.return_value = "test-secret-key" + mock_hash.return_value = "test-hash" mock_orchestrator = Mock(spec=CustomOrchestrator) - - with patch('builtins.open', mock_open(read_data=b"test data")): + + with patch("builtins.open", mock_open(read_data=b"test data")): secret_key = prepare_for_smd( model_path=str(model_path), shared_libs=[], dependencies={}, - inference_spec=mock_orchestrator + inference_spec=mock_orchestrator, ) - - self.assertEqual(secret_key, "test-secret-key") + # Verify custom_execution_inference.py was copied and renamed mock_rename.assert_called_once() - @patch('sagemaker.serve.model_server.smd.prepare.compute_hash') - @patch('sagemaker.serve.model_server.smd.prepare.generate_secret_key') - 
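# Aside (minimal runnable sketch, not part of this patch): these prepare tests
# lean on the mock_open pattern -- patching builtins.open so code that reads
# serve.pkl sees canned bytes without touching the filesystem. `read_artifact`
# below is a hypothetical stand-in for the code under test:
from unittest.mock import mock_open, patch

def read_artifact(path):
    with open(path, "rb") as f:
        return f.read()

with patch("builtins.open", mock_open(read_data=b"test data")):
    assert read_artifact("/any/path/serve.pkl") == b"test data"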
@patch('sagemaker.serve.model_server.smd.prepare.capture_dependencies') - @patch('shutil.copy2') - def test_prepare_for_smd_with_shared_libs(self, mock_copy, mock_capture, mock_gen_key, mock_hash): + @patch("sagemaker.serve.model_server.smd.prepare.compute_hash") + @patch("sagemaker.serve.model_server.smd.prepare.capture_dependencies") + @patch("shutil.copy2") + def test_prepare_for_smd_with_shared_libs(self, mock_copy, mock_capture, mock_hash): """Test prepare_for_smd copies shared libraries.""" from sagemaker.serve.model_server.smd.prepare import prepare_for_smd - + model_path = Path(self.temp_dir) / "model" code_dir = model_path / "code" code_dir.mkdir(parents=True) - + serve_pkl = code_dir / "serve.pkl" serve_pkl.write_bytes(b"test data") - + shared_lib = Path(self.temp_dir) / "lib.so" shared_lib.touch() - - mock_gen_key.return_value = "test-key" + mock_hash.return_value = "test-hash" - - with patch('builtins.open', mock_open(read_data=b"test data")): + + with patch("builtins.open", mock_open(read_data=b"test data")): prepare_for_smd( - model_path=str(model_path), - shared_libs=[str(shared_lib)], - dependencies={} + model_path=str(model_path), shared_libs=[str(shared_lib)], dependencies={} ) - + # Verify copy2 was called for shared lib self.assertTrue(any(str(shared_lib) in str(call) for call in mock_copy.call_args_list)) def test_prepare_for_smd_invalid_dir(self): """Test prepare_for_smd raises exception for invalid directory.""" from sagemaker.serve.model_server.smd.prepare import prepare_for_smd - + file_path = Path(self.temp_dir) / "file.txt" file_path.touch() - + with self.assertRaises(Exception) as context: - prepare_for_smd( - model_path=str(file_path), - shared_libs=[], - dependencies={} - ) + prepare_for_smd(model_path=str(file_path), shared_libs=[], dependencies={}) self.assertIn("not a valid directory", str(context.exception)) diff --git a/sagemaker-serve/tests/unit/model_server/test_smd_server.py b/sagemaker-serve/tests/unit/model_server/test_smd_server.py index 8bf7d4424e..c88331219f 100644 --- a/sagemaker-serve/tests/unit/model_server/test_smd_server.py +++ b/sagemaker-serve/tests/unit/model_server/test_smd_server.py @@ -7,80 +7,78 @@ class TestSageMakerSmdServer(unittest.TestCase): """Test SageMakerSmdServer class.""" - @patch('sagemaker.serve.model_server.smd.server._is_s3_uri') + @patch("sagemaker.serve.model_server.smd.server._is_s3_uri") def test_upload_smd_artifacts_with_s3_path(self, mock_is_s3): """Test _upload_smd_artifacts with S3 path.""" from sagemaker.serve.model_server.smd.server import SageMakerSmdServer - + server = SageMakerSmdServer() mock_is_s3.return_value = True mock_session = Mock() mock_session.boto_region_name = "us-west-2" - + s3_path, env_vars = server._upload_smd_artifacts( model_path="s3://bucket/model", sagemaker_session=mock_session, secret_key="test-key", - should_upload_artifacts=False + should_upload_artifacts=False, ) - + self.assertEqual(s3_path, "s3://bucket/model") - self.assertIn("SAGEMAKER_SERVE_SECRET_KEY", env_vars) - self.assertEqual(env_vars["SAGEMAKER_SERVE_SECRET_KEY"], "test-key") self.assertIn("SAGEMAKER_INFERENCE_CODE_DIRECTORY", env_vars) - @patch('sagemaker.serve.model_server.smd.server.upload') - @patch('sagemaker.serve.model_server.smd.server.determine_bucket_and_prefix') - @patch('sagemaker.serve.model_server.smd.server.parse_s3_url') - @patch('sagemaker.serve.model_server.smd.server.fw_utils') - @patch('sagemaker.serve.model_server.smd.server._is_s3_uri') - def test_upload_smd_artifacts_uploads_to_s3(self, mock_is_s3, 
mock_fw_utils, - mock_parse, mock_determine, mock_upload): + @patch("sagemaker.serve.model_server.smd.server.upload") + @patch("sagemaker.serve.model_server.smd.server.determine_bucket_and_prefix") + @patch("sagemaker.serve.model_server.smd.server.parse_s3_url") + @patch("sagemaker.serve.model_server.smd.server.fw_utils") + @patch("sagemaker.serve.model_server.smd.server._is_s3_uri") + def test_upload_smd_artifacts_uploads_to_s3( + self, mock_is_s3, mock_fw_utils, mock_parse, mock_determine, mock_upload + ): """Test _upload_smd_artifacts uploads to S3.""" from sagemaker.serve.model_server.smd.server import SageMakerSmdServer - + server = SageMakerSmdServer() mock_is_s3.return_value = False mock_parse.return_value = ("bucket", "prefix") mock_determine.return_value = ("bucket", "code_prefix") mock_upload.return_value = "s3://bucket/code_prefix/model.tar.gz" - + mock_session = Mock() mock_session.boto_region_name = "us-west-2" - + s3_path, env_vars = server._upload_smd_artifacts( model_path="/local/model", sagemaker_session=mock_session, secret_key="test-key", s3_model_data_url="s3://bucket/prefix", image="test-image", - should_upload_artifacts=True + should_upload_artifacts=True, ) - + self.assertEqual(s3_path, "s3://bucket/code_prefix/model.tar.gz") - self.assertIn("SAGEMAKER_SERVE_SECRET_KEY", env_vars) self.assertIn("SAGEMAKER_INFERENCE_CODE", env_vars) mock_upload.assert_called_once() - @patch('sagemaker.serve.model_server.smd.server._is_s3_uri') + @patch("sagemaker.serve.model_server.smd.server._is_s3_uri") def test_upload_smd_artifacts_no_upload(self, mock_is_s3): """Test _upload_smd_artifacts without uploading.""" from sagemaker.serve.model_server.smd.server import SageMakerSmdServer - + server = SageMakerSmdServer() mock_is_s3.return_value = False mock_session = Mock() mock_session.boto_region_name = "us-west-2" - + s3_path, env_vars = server._upload_smd_artifacts( model_path="/local/model", sagemaker_session=mock_session, secret_key="test-key", - should_upload_artifacts=False + should_upload_artifacts=False, ) - + self.assertIsNone(s3_path) - self.assertIn("SAGEMAKER_SERVE_SECRET_KEY", env_vars) + self.assertIn("SAGEMAKER_INFERENCE_CODE_DIRECTORY", env_vars) if __name__ == "__main__": diff --git a/sagemaker-serve/tests/unit/model_server/test_tei_server.py b/sagemaker-serve/tests/unit/model_server/test_tei_server.py index c280e4b546..4fff01710b 100644 --- a/sagemaker-serve/tests/unit/model_server/test_tei_server.py +++ b/sagemaker-serve/tests/unit/model_server/test_tei_server.py @@ -8,99 +8,99 @@ class TestLocalTeiServing(unittest.TestCase): """Test LocalTeiServing class.""" - @patch('sagemaker.serve.model_server.tei.server._update_env_vars') - @patch('sagemaker.serve.model_server.tei.server.Path') - @patch('sagemaker.serve.model_server.tei.server.DeviceRequest') + @patch("sagemaker.serve.model_server.tei.server._update_env_vars") + @patch("sagemaker.serve.model_server.tei.server.Path") + @patch("sagemaker.serve.model_server.tei.server.DeviceRequest") def test_start_tei_serving(self, mock_device_req, mock_path, mock_update_env): """Test _start_tei_serving creates container.""" from sagemaker.serve.model_server.tei.server import LocalTeiServing - + server = LocalTeiServing() mock_client = Mock() mock_container = Mock() mock_client.containers.run.return_value = mock_container - + mock_path_obj = Mock() mock_path.return_value.joinpath.return_value = mock_path_obj mock_device_req.return_value = Mock() mock_update_env.return_value = {"HF_HOME": "/opt/ml/model/"} - + 
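# Aside: the TEI tests only pin down the *contract* of _update_env_vars
# (defaults such as HF_HOME survive, custom keys are merged in). A
# hypothetical implementation consistent with those assertions -- not the
# library's actual code -- could look like `update_env_vars_sketch`:
def update_env_vars_sketch(env_vars):
    updated = {"HF_HOME": "/opt/ml/model/", "HUGGINGFACE_HUB_CACHE": "/opt/ml/model/"}
    if env_vars:
        updated.update(env_vars)  # custom keys win over defaults
    return updated

assert update_env_vars_sketch({"CUSTOM_KEY": "custom_value"})["HF_HOME"] == "/opt/ml/model/"
assert update_env_vars_sketch(None)["HUGGINGFACE_HUB_CACHE"] == "/opt/ml/model/"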
        server._start_tei_serving(
            client=mock_client,
            image="tei:latest",
            model_path="/path/to/model",
            secret_key="test-secret",
-            env_vars={"CUSTOM_VAR": "value"}
+            env_vars={"CUSTOM_VAR": "value"},
        )
-
+
        self.assertEqual(server.container, mock_container)
        mock_client.containers.run.assert_called_once()

-    @patch('sagemaker.serve.model_server.tei.server._update_env_vars')
-    @patch('sagemaker.serve.model_server.tei.server.Path')
-    @patch('sagemaker.serve.model_server.tei.server.DeviceRequest')
-    def test_start_tei_serving_adds_secret_key(self, mock_device_req, mock_path, mock_update_env):
-        """Test _start_tei_serving adds secret key to env vars."""
+    @patch("sagemaker.serve.model_server.tei.server._update_env_vars")
+    @patch("sagemaker.serve.model_server.tei.server.Path")
+    @patch("sagemaker.serve.model_server.tei.server.DeviceRequest")
+    def test_start_tei_serving_does_not_add_secret_key(
+        self, mock_device_req, mock_path, mock_update_env
+    ):
+        """Test _start_tei_serving no longer adds secret key to env vars."""
        from sagemaker.serve.model_server.tei.server import LocalTeiServing
-
+
        server = LocalTeiServing()
        mock_client = Mock()
        mock_container = Mock()
        mock_client.containers.run.return_value = mock_container
-
+
        mock_path_obj = Mock()
        mock_path.return_value.joinpath.return_value = mock_path_obj
        mock_device_req.return_value = Mock()
        mock_update_env.return_value = {"HF_HOME": "/opt/ml/model/"}
-
+
        env_vars = {"CUSTOM_VAR": "value"}
        server._start_tei_serving(
            client=mock_client,
            image="tei:latest",
            model_path="/path/to/model",
            secret_key="test-secret",
-            env_vars=env_vars
+            env_vars=env_vars,
        )
-
-        # Verify secret key was added to env_vars
-        self.assertEqual(env_vars["SAGEMAKER_SERVE_SECRET_KEY"], "test-secret")

-    @patch('sagemaker.serve.model_server.tei.server.requests.post')
-    @patch('sagemaker.serve.model_server.tei.server.get_docker_host')
+        # Verify secret key is NOT added to env_vars
+        self.assertNotIn("SAGEMAKER_SERVE_SECRET_KEY", env_vars)
+
+    @patch("sagemaker.serve.model_server.tei.server.requests.post")
+    @patch("sagemaker.serve.model_server.tei.server.get_docker_host")
    def test_invoke_tei_serving_success(self, mock_get_host, mock_post):
        """Test _invoke_tei_serving successful request."""
        from sagemaker.serve.model_server.tei.server import LocalTeiServing
-
+
        server = LocalTeiServing()
        mock_get_host.return_value = "localhost"
        mock_response = Mock()
        mock_response.content = b'{"embeddings": [[0.1, 0.2]]}'
        mock_post.return_value = mock_response
-
+
        result = server._invoke_tei_serving(
            request='{"inputs": "test text"}',
            content_type="application/json",
-            accept="application/json"
+            accept="application/json",
        )
-
+
        self.assertEqual(result, b'{"embeddings": [[0.1, 0.2]]}')
        mock_post.assert_called_once()

-    @patch('sagemaker.serve.model_server.tei.server.requests.post')
-    @patch('sagemaker.serve.model_server.tei.server.get_docker_host')
+    @patch("sagemaker.serve.model_server.tei.server.requests.post")
+    @patch("sagemaker.serve.model_server.tei.server.get_docker_host")
    def test_invoke_tei_serving_failure(self, mock_get_host, mock_post):
        """Test _invoke_tei_serving handles errors."""
        from sagemaker.serve.model_server.tei.server import LocalTeiServing
-
+
        server = LocalTeiServing()
        mock_get_host.return_value = "localhost"
        mock_post.side_effect = Exception("Connection error")
-
+
        with self.assertRaises(Exception) as context:
            server._invoke_tei_serving(
                request='{"inputs": "test"}',
                content_type="application/json",
-                accept="application/json"
+                accept="application/json",
            )
        self.assertIn("Unable to send request", str(context.exception))

@@ -108,40 +108,48 @@ def test_invoke_tei_serving_failure(self,
mock_get_host, mock_post): class TestSageMakerTeiServing(unittest.TestCase): """Test SageMakerTeiServing class.""" - @patch('sagemaker.serve.model_server.tei.server._update_env_vars') - @patch('sagemaker.serve.model_server.tei.server._is_s3_uri') + @patch("sagemaker.serve.model_server.tei.server._update_env_vars") + @patch("sagemaker.serve.model_server.tei.server._is_s3_uri") def test_upload_tei_artifacts_with_s3_path(self, mock_is_s3, mock_update_env): """Test _upload_tei_artifacts with S3 path.""" from sagemaker.serve.model_server.tei.server import SageMakerTeiServing - + server = SageMakerTeiServing() mock_is_s3.return_value = True mock_update_env.return_value = {"HF_HOME": "/opt/ml/model/"} mock_session = Mock() - + model_data, env_vars = server._upload_tei_artifacts( model_path="s3://bucket/model", sagemaker_session=mock_session, - should_upload_artifacts=False + should_upload_artifacts=False, ) - + self.assertIsNotNone(model_data) self.assertEqual(model_data["S3DataSource"]["S3Uri"], "s3://bucket/model/") - @patch('sagemaker.serve.model_server.tei.server._update_env_vars') - @patch('sagemaker.serve.model_server.tei.server.S3Uploader') - @patch('sagemaker.serve.model_server.tei.server.s3_path_join') - @patch('sagemaker.serve.model_server.tei.server.determine_bucket_and_prefix') - @patch('sagemaker.serve.model_server.tei.server.parse_s3_url') - @patch('sagemaker.serve.model_server.tei.server.fw_utils') - @patch('sagemaker.serve.model_server.tei.server._is_s3_uri') - @patch('sagemaker.serve.model_server.tei.server.Path') - def test_upload_tei_artifacts_uploads_to_s3(self, mock_path, mock_is_s3, mock_fw_utils, - mock_parse, mock_determine, mock_s3_join, - mock_uploader, mock_update_env): + @patch("sagemaker.serve.model_server.tei.server._update_env_vars") + @patch("sagemaker.serve.model_server.tei.server.S3Uploader") + @patch("sagemaker.serve.model_server.tei.server.s3_path_join") + @patch("sagemaker.serve.model_server.tei.server.determine_bucket_and_prefix") + @patch("sagemaker.serve.model_server.tei.server.parse_s3_url") + @patch("sagemaker.serve.model_server.tei.server.fw_utils") + @patch("sagemaker.serve.model_server.tei.server._is_s3_uri") + @patch("sagemaker.serve.model_server.tei.server.Path") + def test_upload_tei_artifacts_uploads_to_s3( + self, + mock_path, + mock_is_s3, + mock_fw_utils, + mock_parse, + mock_determine, + mock_s3_join, + mock_uploader, + mock_update_env, + ): """Test _upload_tei_artifacts uploads to S3.""" from sagemaker.serve.model_server.tei.server import SageMakerTeiServing - + server = SageMakerTeiServing() mock_is_s3.return_value = False mock_parse.return_value = ("bucket", "prefix") @@ -149,43 +157,41 @@ def test_upload_tei_artifacts_uploads_to_s3(self, mock_path, mock_is_s3, mock_fw mock_s3_join.return_value = "s3://bucket/code_prefix/code" mock_uploader.upload.return_value = "s3://bucket/code_prefix/code" mock_update_env.return_value = {"HF_HOME": "/opt/ml/model/"} - + mock_path_obj = Mock() mock_code_dir = Mock() mock_path_obj.joinpath.return_value = mock_code_dir mock_path.return_value = mock_path_obj - + mock_session = Mock() - + model_data, env_vars = server._upload_tei_artifacts( model_path="/local/model", sagemaker_session=mock_session, s3_model_data_url="s3://bucket/prefix", image="test-image", env_vars={"CUSTOM": "var"}, - should_upload_artifacts=True + should_upload_artifacts=True, ) - + self.assertIsNotNone(model_data) mock_uploader.upload.assert_called_once() - @patch('sagemaker.serve.model_server.tei.server._update_env_vars') - 
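# Aside (small runnable sketch, not part of this patch): the call-inspection
# idioms used throughout these tests -- call_args[1] is the kwargs of the
# most recent call, and call_args_list records every call for
# membership-style checks like the shared-lib assertions above:
from unittest.mock import Mock

m = Mock()
m(1, headers={"Accept": "application/json"})
m(2, headers={"Accept": "text/csv"})

assert m.call_args[1]["headers"]["Accept"] == "text/csv"  # kwargs of last call
assert any("application/json" in str(call) for call in m.call_args_list)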
@patch('sagemaker.serve.model_server.tei.server._is_s3_uri') + @patch("sagemaker.serve.model_server.tei.server._update_env_vars") + @patch("sagemaker.serve.model_server.tei.server._is_s3_uri") def test_upload_tei_artifacts_no_upload(self, mock_is_s3, mock_update_env): """Test _upload_tei_artifacts without uploading.""" from sagemaker.serve.model_server.tei.server import SageMakerTeiServing - + server = SageMakerTeiServing() mock_is_s3.return_value = False mock_update_env.return_value = {"HF_HOME": "/opt/ml/model/"} mock_session = Mock() - + model_data, env_vars = server._upload_tei_artifacts( - model_path="/local/model", - sagemaker_session=mock_session, - should_upload_artifacts=False + model_path="/local/model", sagemaker_session=mock_session, should_upload_artifacts=False ) - + self.assertIsNone(model_data) @@ -195,7 +201,7 @@ class TestUpdateEnvVars(unittest.TestCase): def test_update_env_vars_with_none(self): """Test _update_env_vars with None input.""" from sagemaker.serve.model_server.tei.server import _update_env_vars - + result = _update_env_vars(None) self.assertIn("HF_HOME", result) self.assertIn("HUGGINGFACE_HUB_CACHE", result) @@ -203,10 +209,10 @@ def test_update_env_vars_with_none(self): def test_update_env_vars_with_custom_vars(self): """Test _update_env_vars with custom variables.""" from sagemaker.serve.model_server.tei.server import _update_env_vars - + custom_vars = {"CUSTOM_KEY": "custom_value"} result = _update_env_vars(custom_vars) - + self.assertIn("CUSTOM_KEY", result) self.assertIn("HF_HOME", result) self.assertEqual(result["CUSTOM_KEY"], "custom_value") diff --git a/sagemaker-serve/tests/unit/model_server/test_tensorflow_serving_inference.py b/sagemaker-serve/tests/unit/model_server/test_tensorflow_serving_inference.py index 14aad247c3..cc6d7b967e 100644 --- a/sagemaker-serve/tests/unit/model_server/test_tensorflow_serving_inference.py +++ b/sagemaker-serve/tests/unit/model_server/test_tensorflow_serving_inference.py @@ -16,6 +16,7 @@ class TestTensorFlowServingInference(unittest.TestCase): def test_input_handler_logic(self): """Test input_handler logic.""" + def input_handler(data, context, schema_builder): read_data = data.read() if hasattr(schema_builder, "custom_input_translator"): @@ -27,32 +28,33 @@ def input_handler(data, context, schema_builder): io.BytesIO(read_data), context.request_content_type ) return json.dumps({"instances": deserialized_data}) - + schema_builder = Mock() schema_builder.custom_input_translator = Mock() schema_builder.custom_input_translator.deserialize = Mock(return_value=[[1, 2, 3]]) - + mock_data = Mock() mock_data.read = Mock(return_value=b'{"data": [1, 2, 3]}') - + mock_context = Mock() mock_context.request_content_type = "application/json" - + result = input_handler(mock_data, mock_context, schema_builder) - + expected = json.dumps({"instances": [[1, 2, 3]]}) self.assertEqual(result, expected) def test_output_handler_logic(self): """Test output_handler logic.""" + def output_handler(data, context, schema_builder): if data.status_code != 200: raise ValueError(data.content.decode("utf-8")) - + response_content_type = context.accept_header prediction = data.content prediction_dict = json.loads(prediction.decode("utf-8")) - + if hasattr(schema_builder, "custom_output_translator"): return ( schema_builder.custom_output_translator.serialize( @@ -61,58 +63,66 @@ def output_handler(data, context, schema_builder): response_content_type, ) else: - return schema_builder.output_serializer.serialize(prediction_dict["predictions"]), 
response_content_type - + return ( + schema_builder.output_serializer.serialize(prediction_dict["predictions"]), + response_content_type, + ) + schema_builder = Mock() schema_builder.custom_output_translator = Mock() - schema_builder.custom_output_translator.serialize = Mock(return_value=b'{"predictions": [0.1, 0.9]}') - + schema_builder.custom_output_translator.serialize = Mock( + return_value=b'{"predictions": [0.1, 0.9]}' + ) + mock_data = Mock() mock_data.status_code = 200 - mock_data.content = json.dumps({"predictions": [0.1, 0.9]}).encode('utf-8') - + mock_data.content = json.dumps({"predictions": [0.1, 0.9]}).encode("utf-8") + mock_context = Mock() mock_context.accept_header = "application/json" - + result, content_type = output_handler(mock_data, mock_context, schema_builder) - + self.assertEqual(result, b'{"predictions": [0.1, 0.9]}') self.assertEqual(content_type, "application/json") def test_convert_numpy_array_logic(self): """Test conversion of numpy array.""" + def _convert_for_serialization(deserialized_data): if isinstance(deserialized_data, np.ndarray): return deserialized_data.tolist() return deserialized_data - + data = np.array([[1, 2, 3], [4, 5, 6]]) result = _convert_for_serialization(data) - + self.assertEqual(result, [[1, 2, 3], [4, 5, 6]]) def test_convert_pandas_dataframe_logic(self): """Test conversion of pandas DataFrame.""" + def _convert_for_serialization(deserialized_data): if isinstance(deserialized_data, pd.DataFrame): return deserialized_data.to_dict(orient="list") return deserialized_data - - data = pd.DataFrame({'a': [1, 2], 'b': [3, 4]}) + + data = pd.DataFrame({"a": [1, 2], "b": [3, 4]}) result = _convert_for_serialization(data) - - self.assertEqual(result, {'a': [1, 2], 'b': [3, 4]}) + + self.assertEqual(result, {"a": [1, 2], "b": [3, 4]}) def test_convert_pandas_series_logic(self): """Test conversion of pandas Series.""" + def _convert_for_serialization(deserialized_data): if isinstance(deserialized_data, pd.Series): return deserialized_data.tolist() return deserialized_data - + data = pd.Series([1, 2, 3, 4]) result = _convert_for_serialization(data) - + self.assertEqual(result, [1, 2, 3, 4]) diff --git a/sagemaker-serve/tests/unit/model_server/test_tensorflow_serving_prepare.py b/sagemaker-serve/tests/unit/model_server/test_tensorflow_serving_prepare.py index e6ca1161dc..c78797be04 100644 --- a/sagemaker-serve/tests/unit/model_server/test_tensorflow_serving_prepare.py +++ b/sagemaker-serve/tests/unit/model_server/test_tensorflow_serving_prepare.py @@ -17,122 +17,114 @@ def tearDown(self): if Path(self.temp_dir).exists(): shutil.rmtree(self.temp_dir) - @patch('sagemaker.serve.model_server.tensorflow_serving.prepare._move_contents') - @patch('sagemaker.serve.model_server.tensorflow_serving.prepare._get_saved_model_path_for_tensorflow_and_keras_flavor') - @patch('sagemaker.serve.model_server.tensorflow_serving.prepare.compute_hash') - @patch('sagemaker.serve.model_server.tensorflow_serving.prepare.generate_secret_key') - @patch('sagemaker.serve.model_server.tensorflow_serving.prepare.capture_dependencies') - @patch('shutil.copy2') - def test_prepare_for_tf_serving_success(self, mock_copy, mock_capture, mock_gen_key, - mock_hash, mock_get_saved, mock_move): + @patch("sagemaker.serve.model_server.tensorflow_serving.prepare._move_contents") + @patch( + "sagemaker.serve.model_server.tensorflow_serving.prepare._get_saved_model_path_for_tensorflow_and_keras_flavor" + ) + @patch("sagemaker.serve.model_server.tensorflow_serving.prepare.compute_hash") + 
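# Aside: compute_hash is stubbed in every prepare test; its real definition
# lives in sagemaker.serve.validations.check_integrity and is not shown here.
# As a loose, hypothetical illustration of the idea (digest the pickled
# artifact so its integrity can be verified at load time), assuming nothing
# about the actual signature:
import hashlib

def compute_hash_sketch(buffer: bytes) -> str:
    # SHA-256 hex digest of the artifact bytes
    return hashlib.sha256(buffer).hexdigest()

assert len(compute_hash_sketch(b"test data")) == 64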
@patch("sagemaker.serve.model_server.tensorflow_serving.prepare.capture_dependencies") + @patch("shutil.copy2") + def test_prepare_for_tf_serving_success( + self, mock_copy, mock_capture, mock_hash, mock_get_saved, mock_move + ): """Test prepare_for_tf_serving creates structure successfully.""" from sagemaker.serve.model_server.tensorflow_serving.prepare import prepare_for_tf_serving - + model_path = Path(self.temp_dir) / "model" code_dir = model_path / "code" code_dir.mkdir(parents=True) - + serve_pkl = code_dir / "serve.pkl" serve_pkl.write_bytes(b"test data") - - mock_gen_key.return_value = "test-secret-key" + mock_hash.return_value = "test-hash" mock_get_saved.return_value = Path(self.temp_dir) / "saved_model" - - with patch('builtins.open', mock_open(read_data=b"test data")): + + with patch("builtins.open", mock_open(read_data=b"test data")): secret_key = prepare_for_tf_serving( - model_path=str(model_path), - shared_libs=[], - dependencies={} + model_path=str(model_path), shared_libs=[], dependencies={} ) - - self.assertEqual(secret_key, "test-secret-key") + mock_capture.assert_called_once() mock_move.assert_called_once() - @patch('sagemaker.serve.model_server.tensorflow_serving.prepare._get_saved_model_path_for_tensorflow_and_keras_flavor') - @patch('sagemaker.serve.model_server.tensorflow_serving.prepare.compute_hash') - @patch('sagemaker.serve.model_server.tensorflow_serving.prepare.generate_secret_key') - @patch('sagemaker.serve.model_server.tensorflow_serving.prepare.capture_dependencies') - @patch('shutil.copy2') - def test_prepare_for_tf_serving_no_saved_model(self, mock_copy, mock_capture, mock_gen_key, - mock_hash, mock_get_saved): + @patch( + "sagemaker.serve.model_server.tensorflow_serving.prepare._get_saved_model_path_for_tensorflow_and_keras_flavor" + ) + @patch("sagemaker.serve.model_server.tensorflow_serving.prepare.compute_hash") + @patch("sagemaker.serve.model_server.tensorflow_serving.prepare.capture_dependencies") + @patch("shutil.copy2") + def test_prepare_for_tf_serving_no_saved_model( + self, mock_copy, mock_capture, mock_hash, mock_get_saved + ): """Test prepare_for_tf_serving raises error when SavedModel not found.""" from sagemaker.serve.model_server.tensorflow_serving.prepare import prepare_for_tf_serving - + model_path = Path(self.temp_dir) / "model" code_dir = model_path / "code" code_dir.mkdir(parents=True) - + serve_pkl = code_dir / "serve.pkl" serve_pkl.write_bytes(b"test data") - - mock_gen_key.return_value = "test-secret-key" + mock_hash.return_value = "test-hash" mock_get_saved.return_value = None - + with self.assertRaises(ValueError) as context: - prepare_for_tf_serving( - model_path=str(model_path), - shared_libs=[], - dependencies={} - ) + prepare_for_tf_serving(model_path=str(model_path), shared_libs=[], dependencies={}) self.assertIn("SavedModel is not found", str(context.exception)) - @patch('sagemaker.serve.model_server.tensorflow_serving.prepare._move_contents') - @patch('sagemaker.serve.model_server.tensorflow_serving.prepare._get_saved_model_path_for_tensorflow_and_keras_flavor') - @patch('sagemaker.serve.model_server.tensorflow_serving.prepare.compute_hash') - @patch('sagemaker.serve.model_server.tensorflow_serving.prepare.generate_secret_key') - @patch('sagemaker.serve.model_server.tensorflow_serving.prepare.capture_dependencies') - @patch('shutil.copy2') - def test_prepare_for_tf_serving_with_shared_libs(self, mock_copy, mock_capture, mock_gen_key, - mock_hash, mock_get_saved, mock_move): + 
@patch("sagemaker.serve.model_server.tensorflow_serving.prepare._move_contents") + @patch( + "sagemaker.serve.model_server.tensorflow_serving.prepare._get_saved_model_path_for_tensorflow_and_keras_flavor" + ) + @patch("sagemaker.serve.model_server.tensorflow_serving.prepare.compute_hash") + @patch("sagemaker.serve.model_server.tensorflow_serving.prepare.capture_dependencies") + @patch("shutil.copy2") + def test_prepare_for_tf_serving_with_shared_libs( + self, mock_copy, mock_capture, mock_hash, mock_get_saved, mock_move + ): """Test prepare_for_tf_serving copies shared libraries.""" from sagemaker.serve.model_server.tensorflow_serving.prepare import prepare_for_tf_serving - + model_path = Path(self.temp_dir) / "model" code_dir = model_path / "code" code_dir.mkdir(parents=True) - + serve_pkl = code_dir / "serve.pkl" serve_pkl.write_bytes(b"test data") - + shared_lib = Path(self.temp_dir) / "lib.so" shared_lib.touch() - - mock_gen_key.return_value = "test-key" + mock_hash.return_value = "test-hash" mock_get_saved.return_value = Path(self.temp_dir) / "saved_model" - - with patch('builtins.open', mock_open(read_data=b"test data")): + + with patch("builtins.open", mock_open(read_data=b"test data")): prepare_for_tf_serving( - model_path=str(model_path), - shared_libs=[str(shared_lib)], - dependencies={} + model_path=str(model_path), shared_libs=[str(shared_lib)], dependencies={} ) - + # Verify copy2 was called for shared lib self.assertTrue(any(str(shared_lib) in str(call) for call in mock_copy.call_args_list)) - @patch('sagemaker.serve.model_server.tensorflow_serving.prepare._get_saved_model_path_for_tensorflow_and_keras_flavor') - @patch('sagemaker.serve.model_server.tensorflow_serving.prepare.compute_hash') - @patch('sagemaker.serve.model_server.tensorflow_serving.prepare.generate_secret_key') - @patch('sagemaker.serve.model_server.tensorflow_serving.prepare.capture_dependencies') - @patch('shutil.copy2') - def test_prepare_for_tf_serving_invalid_dir(self, mock_copy, mock_capture, mock_gen_key, - mock_hash, mock_get_saved): + @patch( + "sagemaker.serve.model_server.tensorflow_serving.prepare._get_saved_model_path_for_tensorflow_and_keras_flavor" + ) + @patch("sagemaker.serve.model_server.tensorflow_serving.prepare.compute_hash") + @patch("sagemaker.serve.model_server.tensorflow_serving.prepare.capture_dependencies") + @patch("shutil.copy2") + def test_prepare_for_tf_serving_invalid_dir( + self, mock_copy, mock_capture, mock_hash, mock_get_saved + ): """Test prepare_for_tf_serving raises exception for invalid directory.""" from sagemaker.serve.model_server.tensorflow_serving.prepare import prepare_for_tf_serving - + file_path = Path(self.temp_dir) / "file.txt" file_path.touch() - + with self.assertRaises(Exception) as context: - prepare_for_tf_serving( - model_path=str(file_path), - shared_libs=[], - dependencies={} - ) + prepare_for_tf_serving(model_path=str(file_path), shared_libs=[], dependencies={}) self.assertIn("not a valid directory", str(context.exception)) diff --git a/sagemaker-serve/tests/unit/model_server/test_tensorflow_serving_server.py b/sagemaker-serve/tests/unit/model_server/test_tensorflow_serving_server.py index d0bac2e5dc..4013b5c11c 100644 --- a/sagemaker-serve/tests/unit/model_server/test_tensorflow_serving_server.py +++ b/sagemaker-serve/tests/unit/model_server/test_tensorflow_serving_server.py @@ -8,70 +8,68 @@ class TestLocalTensorflowServing(unittest.TestCase): """Test LocalTensorflowServing class.""" - 
@patch('sagemaker.serve.model_server.tensorflow_serving.server.Path') + @patch("sagemaker.serve.model_server.tensorflow_serving.server.Path") def test_start_tensorflow_serving(self, mock_path): """Test _start_tensorflow_serving creates container.""" from sagemaker.serve.model_server.tensorflow_serving.server import LocalTensorflowServing - + server = LocalTensorflowServing() mock_client = Mock() mock_container = Mock() mock_client.containers.run.return_value = mock_container - + mock_path_obj = Mock() mock_path.return_value = mock_path_obj - + server._start_tensorflow_serving( client=mock_client, image="tensorflow-serving:latest", model_path="/path/to/model", secret_key="test-secret", - env_vars={"CUSTOM_VAR": "value"} + env_vars={"CUSTOM_VAR": "value"}, ) - + self.assertEqual(server.container, mock_container) mock_client.containers.run.assert_called_once() call_kwargs = mock_client.containers.run.call_args[1] - self.assertIn("SAGEMAKER_SERVE_SECRET_KEY", call_kwargs["environment"]) - self.assertEqual(call_kwargs["environment"]["SAGEMAKER_SERVE_SECRET_KEY"], "test-secret") self.assertEqual(call_kwargs["environment"]["CUSTOM_VAR"], "value") - @patch('sagemaker.serve.model_server.tensorflow_serving.server.requests.post') - @patch('sagemaker.serve.model_server.tensorflow_serving.server.get_docker_host') + @patch("sagemaker.serve.model_server.tensorflow_serving.server.requests.post") + @patch("sagemaker.serve.model_server.tensorflow_serving.server.get_docker_host") def test_invoke_tensorflow_serving_success(self, mock_get_host, mock_post): """Test _invoke_tensorflow_serving successful request.""" from sagemaker.serve.model_server.tensorflow_serving.server import LocalTensorflowServing - + server = LocalTensorflowServing() mock_get_host.return_value = "localhost" mock_response = Mock() mock_response.content = b'{"predictions": [[0.1, 0.9]]}' mock_post.return_value = mock_response - + result = server._invoke_tensorflow_serving( request='{"instances": [[1, 2, 3]]}', content_type="application/json", - accept="application/json" + accept="application/json", ) - + self.assertEqual(result, b'{"predictions": [[0.1, 0.9]]}') mock_post.assert_called_once() - @patch('sagemaker.serve.model_server.tensorflow_serving.server.requests.post') - @patch('sagemaker.serve.model_server.tensorflow_serving.server.get_docker_host') + @patch("sagemaker.serve.model_server.tensorflow_serving.server.requests.post") + @patch("sagemaker.serve.model_server.tensorflow_serving.server.get_docker_host") def test_invoke_tensorflow_serving_failure(self, mock_get_host, mock_post): """Test _invoke_tensorflow_serving handles errors.""" from sagemaker.serve.model_server.tensorflow_serving.server import LocalTensorflowServing - + server = LocalTensorflowServing() mock_get_host.return_value = "localhost" mock_post.side_effect = Exception("Connection error") - + with self.assertRaises(Exception) as context: server._invoke_tensorflow_serving( request='{"instances": [[1, 2, 3]]}', content_type="application/json", - accept="application/json" + accept="application/json", ) self.assertIn("Unable to send request", str(context.exception)) @@ -79,78 +77,84 @@ def test_invoke_tensorflow_serving_failure(self, mock_get_host, mock_post): class TestSageMakerTensorflowServing(unittest.TestCase): """Test SageMakerTensorflowServing class.""" - @patch('sagemaker.serve.model_server.tensorflow_serving.server._is_s3_uri') + @patch("sagemaker.serve.model_server.tensorflow_serving.server._is_s3_uri") def 
test_upload_tensorflow_serving_artifacts_with_s3_path(self, mock_is_s3): """Test _upload_tensorflow_serving_artifacts with S3 path.""" - from sagemaker.serve.model_server.tensorflow_serving.server import SageMakerTensorflowServing - + from sagemaker.serve.model_server.tensorflow_serving.server import ( + SageMakerTensorflowServing, + ) + server = SageMakerTensorflowServing() mock_is_s3.return_value = True mock_session = Mock() mock_session.boto_region_name = "us-west-2" - + s3_path, env_vars = server._upload_tensorflow_serving_artifacts( model_path="s3://bucket/model", sagemaker_session=mock_session, secret_key="test-key", - should_upload_artifacts=False + should_upload_artifacts=False, ) - + self.assertEqual(s3_path, "s3://bucket/model") - self.assertIn("SAGEMAKER_SERVE_SECRET_KEY", env_vars) - self.assertEqual(env_vars["SAGEMAKER_SERVE_SECRET_KEY"], "test-key") - - @patch('sagemaker.serve.model_server.tensorflow_serving.server.upload') - @patch('sagemaker.serve.model_server.tensorflow_serving.server.determine_bucket_and_prefix') - @patch('sagemaker.serve.model_server.tensorflow_serving.server.parse_s3_url') - @patch('sagemaker.serve.model_server.tensorflow_serving.server.fw_utils') - @patch('sagemaker.serve.model_server.tensorflow_serving.server._is_s3_uri') - def test_upload_tensorflow_serving_artifacts_uploads_to_s3(self, mock_is_s3, mock_fw_utils, - mock_parse, mock_determine, mock_upload): + self.assertIn("SAGEMAKER_SUBMIT_DIRECTORY", env_vars) + + @patch("sagemaker.serve.model_server.tensorflow_serving.server.upload") + @patch("sagemaker.serve.model_server.tensorflow_serving.server.determine_bucket_and_prefix") + @patch("sagemaker.serve.model_server.tensorflow_serving.server.parse_s3_url") + @patch("sagemaker.serve.model_server.tensorflow_serving.server.fw_utils") + @patch("sagemaker.serve.model_server.tensorflow_serving.server._is_s3_uri") + def test_upload_tensorflow_serving_artifacts_uploads_to_s3( + self, mock_is_s3, mock_fw_utils, mock_parse, mock_determine, mock_upload + ): """Test _upload_tensorflow_serving_artifacts uploads to S3.""" - from sagemaker.serve.model_server.tensorflow_serving.server import SageMakerTensorflowServing - + from sagemaker.serve.model_server.tensorflow_serving.server import ( + SageMakerTensorflowServing, + ) + server = SageMakerTensorflowServing() mock_is_s3.return_value = False mock_parse.return_value = ("bucket", "prefix") mock_determine.return_value = ("bucket", "code_prefix") mock_upload.return_value = "s3://bucket/code_prefix/model.tar.gz" - + mock_session = Mock() mock_session.boto_region_name = "us-west-2" - + s3_path, env_vars = server._upload_tensorflow_serving_artifacts( model_path="/local/model", sagemaker_session=mock_session, secret_key="test-key", s3_model_data_url="s3://bucket/prefix", image="test-image", - should_upload_artifacts=True + should_upload_artifacts=True, ) - + self.assertEqual(s3_path, "s3://bucket/code_prefix/model.tar.gz") - self.assertIn("SAGEMAKER_SERVE_SECRET_KEY", env_vars) + self.assertIn("SAGEMAKER_SUBMIT_DIRECTORY", env_vars) mock_upload.assert_called_once() - @patch('sagemaker.serve.model_server.tensorflow_serving.server._is_s3_uri') + @patch("sagemaker.serve.model_server.tensorflow_serving.server._is_s3_uri") def test_upload_tensorflow_serving_artifacts_no_upload(self, mock_is_s3): """Test _upload_tensorflow_serving_artifacts without uploading.""" - from sagemaker.serve.model_server.tensorflow_serving.server import SageMakerTensorflowServing - + from sagemaker.serve.model_server.tensorflow_serving.server import 
( + SageMakerTensorflowServing, + ) + server = SageMakerTensorflowServing() mock_is_s3.return_value = False mock_session = Mock() mock_session.boto_region_name = "us-west-2" - + s3_path, env_vars = server._upload_tensorflow_serving_artifacts( model_path="/local/model", sagemaker_session=mock_session, secret_key="test-key", - should_upload_artifacts=False + should_upload_artifacts=False, ) - + self.assertIsNone(s3_path) - self.assertIn("SAGEMAKER_SERVE_SECRET_KEY", env_vars) + self.assertIn("SAGEMAKER_SUBMIT_DIRECTORY", env_vars) if __name__ == "__main__": diff --git a/sagemaker-serve/tests/unit/model_server/test_tgi_prepare.py b/sagemaker-serve/tests/unit/model_server/test_tgi_prepare.py index 79417b7b74..992b83d2be 100644 --- a/sagemaker-serve/tests/unit/model_server/test_tgi_prepare.py +++ b/sagemaker-serve/tests/unit/model_server/test_tgi_prepare.py @@ -18,187 +18,175 @@ def tearDown(self): if Path(self.temp_dir).exists(): shutil.rmtree(self.temp_dir) - @patch('tarfile.open') - @patch('sagemaker.serve.model_server.tgi.prepare.custom_extractall_tarfile') + @patch("tarfile.open") + @patch("sagemaker.serve.model_server.tgi.prepare.custom_extractall_tarfile") def test_extract_js_resource(self, mock_extract, mock_tarfile): """Test _extract_js_resource extracts tarball.""" from sagemaker.serve.model_server.tgi.prepare import _extract_js_resource - + js_model_dir = self.temp_dir code_dir = Path(self.temp_dir) / "code" code_dir.mkdir() - + # Create a dummy tar file tar_path = Path(js_model_dir) / "infer-prepack-test-id.tar.gz" tar_path.touch() - + mock_tar = Mock() mock_tarfile.return_value.__enter__.return_value = mock_tar - + _extract_js_resource(js_model_dir, code_dir, "test-id") - + mock_extract.assert_called_once_with(mock_tar, code_dir) - @patch('sagemaker.serve.model_server.tgi.prepare.S3Downloader') - @patch('sagemaker.serve.model_server.tgi.prepare._tmpdir') - @patch('sagemaker.serve.model_server.tgi.prepare._extract_js_resource') - def test_copy_jumpstart_artifacts_with_tarball(self, mock_extract, mock_tmpdir, mock_s3_downloader): + @patch("sagemaker.serve.model_server.tgi.prepare.S3Downloader") + @patch("sagemaker.serve.model_server.tgi.prepare._tmpdir") + @patch("sagemaker.serve.model_server.tgi.prepare._extract_js_resource") + def test_copy_jumpstart_artifacts_with_tarball( + self, mock_extract, mock_tmpdir, mock_s3_downloader + ): """Test _copy_jumpstart_artifacts with tar.gz file.""" from sagemaker.serve.model_server.tgi.prepare import _copy_jumpstart_artifacts - + code_dir = Path(self.temp_dir) / "code" code_dir.mkdir() - + # Create config.json config_file = code_dir / "config.json" config_data = {"model_type": "gpt2"} config_file.write_text(json.dumps(config_data)) - + mock_tmpdir.return_value.__enter__.return_value = self.temp_dir mock_downloader_instance = Mock() mock_s3_downloader.return_value = mock_downloader_instance - + result = _copy_jumpstart_artifacts( - model_data="s3://bucket/model.tar.gz", - js_id="test-id", - code_dir=code_dir + model_data="s3://bucket/model.tar.gz", js_id="test-id", code_dir=code_dir ) - + self.assertEqual(result, (config_data, True)) mock_downloader_instance.download.assert_called_once() - @patch('sagemaker.serve.model_server.tgi.prepare.S3Downloader') + @patch("sagemaker.serve.model_server.tgi.prepare.S3Downloader") def test_copy_jumpstart_artifacts_uncompressed(self, mock_s3_downloader): """Test _copy_jumpstart_artifacts with uncompressed data.""" from sagemaker.serve.model_server.tgi.prepare import _copy_jumpstart_artifacts - + code_dir = 
Path(self.temp_dir) / "code" code_dir.mkdir() - + config_file = code_dir / "config.json" config_data = {"model_type": "bert"} config_file.write_text(json.dumps(config_data)) - + mock_downloader_instance = Mock() mock_s3_downloader.return_value = mock_downloader_instance - + result = _copy_jumpstart_artifacts( - model_data="s3://bucket/model/", - js_id="test-id", - code_dir=code_dir + model_data="s3://bucket/model/", js_id="test-id", code_dir=code_dir ) - + self.assertEqual(result, (config_data, True)) - @patch('sagemaker.serve.model_server.tgi.prepare.S3Downloader') + @patch("sagemaker.serve.model_server.tgi.prepare.S3Downloader") def test_copy_jumpstart_artifacts_with_dict(self, mock_s3_downloader): """Test _copy_jumpstart_artifacts with dict model_data.""" from sagemaker.serve.model_server.tgi.prepare import _copy_jumpstart_artifacts - + code_dir = Path(self.temp_dir) / "code" code_dir.mkdir() - + config_file = code_dir / "config.json" config_file.write_text(json.dumps({"model_type": "t5"})) - + mock_downloader_instance = Mock() mock_s3_downloader.return_value = mock_downloader_instance - - model_data = { - "S3DataSource": { - "S3Uri": "s3://bucket/model/" - } - } - + + model_data = {"S3DataSource": {"S3Uri": "s3://bucket/model/"}} + result = _copy_jumpstart_artifacts( - model_data=model_data, - js_id="test-id", - code_dir=code_dir + model_data=model_data, js_id="test-id", code_dir=code_dir ) - + self.assertIsNotNone(result) mock_downloader_instance.download.assert_called_once_with("s3://bucket/model/", code_dir) - @patch('sagemaker.serve.model_server.tgi.prepare.S3Downloader') + @patch("sagemaker.serve.model_server.tgi.prepare.S3Downloader") def test_copy_jumpstart_artifacts_invalid_format(self, mock_s3_downloader): """Test _copy_jumpstart_artifacts raises error for invalid format.""" from sagemaker.serve.model_server.tgi.prepare import _copy_jumpstart_artifacts - + code_dir = Path(self.temp_dir) / "code" code_dir.mkdir() - + mock_downloader_instance = Mock() mock_s3_downloader.return_value = mock_downloader_instance - + with self.assertRaises(ValueError): _copy_jumpstart_artifacts( - model_data={"invalid": "format"}, - js_id="test-id", - code_dir=code_dir + model_data={"invalid": "format"}, js_id="test-id", code_dir=code_dir ) - @patch('sagemaker.serve.model_server.tgi.prepare.S3Downloader') + @patch("sagemaker.serve.model_server.tgi.prepare.S3Downloader") def test_copy_jumpstart_artifacts_no_config(self, mock_s3_downloader): """Test _copy_jumpstart_artifacts when config.json doesn't exist.""" from sagemaker.serve.model_server.tgi.prepare import _copy_jumpstart_artifacts - + code_dir = Path(self.temp_dir) / "code" code_dir.mkdir() - + mock_downloader_instance = Mock() mock_s3_downloader.return_value = mock_downloader_instance - + result = _copy_jumpstart_artifacts( - model_data="s3://bucket/model/", - js_id="test-id", - code_dir=code_dir + model_data="s3://bucket/model/", js_id="test-id", code_dir=code_dir ) - + self.assertEqual(result, (None, True)) - @patch('sagemaker.serve.model_server.tgi.prepare._check_docker_disk_usage') - @patch('sagemaker.serve.model_server.tgi.prepare._check_disk_space') + @patch("sagemaker.serve.model_server.tgi.prepare._check_docker_disk_usage") + @patch("sagemaker.serve.model_server.tgi.prepare._check_disk_space") def test_create_dir_structure(self, mock_disk_space, mock_docker_disk): """Test _create_dir_structure creates directories.""" from sagemaker.serve.model_server.tgi.prepare import _create_dir_structure - + model_path = Path(self.temp_dir) / "model" 
model_path_obj, code_dir = _create_dir_structure(str(model_path)) - + self.assertTrue(model_path.exists()) self.assertTrue(code_dir.exists()) mock_disk_space.assert_called_once() mock_docker_disk.assert_called_once() - @patch('sagemaker.serve.model_server.tgi.prepare._check_docker_disk_usage') - @patch('sagemaker.serve.model_server.tgi.prepare._check_disk_space') + @patch("sagemaker.serve.model_server.tgi.prepare._check_docker_disk_usage") + @patch("sagemaker.serve.model_server.tgi.prepare._check_disk_space") def test_create_dir_structure_raises_on_file(self, mock_disk_space, mock_docker_disk): """Test _create_dir_structure raises ValueError for file path.""" from sagemaker.serve.model_server.tgi.prepare import _create_dir_structure - + file_path = Path(self.temp_dir) / "file.txt" file_path.touch() - + with self.assertRaises(ValueError): _create_dir_structure(str(file_path)) - @patch('sagemaker.serve.model_server.tgi.prepare._copy_jumpstart_artifacts') - @patch('sagemaker.serve.model_server.tgi.prepare._create_dir_structure') + @patch("sagemaker.serve.model_server.tgi.prepare._copy_jumpstart_artifacts") + @patch("sagemaker.serve.model_server.tgi.prepare._create_dir_structure") def test_prepare_tgi_js_resources(self, mock_create_dir, mock_copy_js): """Test prepare_tgi_js_resources.""" from sagemaker.serve.model_server.tgi.prepare import prepare_tgi_js_resources - + mock_model_path = Path(self.temp_dir) / "model" mock_code_dir = mock_model_path / "code" mock_create_dir.return_value = (mock_model_path, mock_code_dir) mock_copy_js.return_value = ({"config": "data"}, True) - + result = prepare_tgi_js_resources( model_path=str(mock_model_path), js_id="test-js-id", - model_data="s3://bucket/model.tar.gz" + model_data="s3://bucket/model.tar.gz", ) - + mock_create_dir.assert_called_once() mock_copy_js.assert_called_once() self.assertEqual(result, ({"config": "data"}, True)) diff --git a/sagemaker-serve/tests/unit/model_server/test_tgi_server.py b/sagemaker-serve/tests/unit/model_server/test_tgi_server.py index 12f84a747b..244ae6462c 100644 --- a/sagemaker-serve/tests/unit/model_server/test_tgi_server.py +++ b/sagemaker-serve/tests/unit/model_server/test_tgi_server.py @@ -8,101 +8,99 @@ class TestLocalTgiServing(unittest.TestCase): """Test LocalTgiServing class.""" - @patch('sagemaker.serve.model_server.tgi.server.Path') - @patch('sagemaker.serve.model_server.tgi.server.DeviceRequest') + @patch("sagemaker.serve.model_server.tgi.server.Path") + @patch("sagemaker.serve.model_server.tgi.server.DeviceRequest") def test_start_tgi_serving_jumpstart(self, mock_device_req, mock_path): """Test _start_tgi_serving with jumpstart=True.""" from sagemaker.serve.model_server.tgi.server import LocalTgiServing - + server = LocalTgiServing() mock_client = Mock() mock_container = Mock() mock_client.containers.run.return_value = mock_container - + mock_path_obj = Mock() mock_path.return_value.joinpath.return_value = mock_path_obj mock_device_req.return_value = Mock() - + server._start_tgi_serving( client=mock_client, image="test-image:latest", model_path="/path/to/model", secret_key="test-secret", env_vars={"CUSTOM_VAR": "value"}, - jumpstart=True + jumpstart=True, ) - + self.assertEqual(server.container, mock_container) mock_client.containers.run.assert_called_once() call_args = mock_client.containers.run.call_args # Check that the command includes --model-id self.assertEqual(call_args[0][1][0], "--model-id") - @patch('sagemaker.serve.model_server.tgi.server._update_env_vars') - 
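# Aside (runnable sketch, not part of this patch): the side_effect idiom used
# by the *_failure tests in this file -- assigning an exception makes the
# mock raise when called, which exercises the wrapper's error-handling path:
from unittest.mock import Mock

mock_post = Mock(side_effect=Exception("Connection error"))
try:
    mock_post("http://localhost:8080/invocations")
except Exception as err:
    assert "Connection error" in str(err)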
diff --git a/sagemaker-serve/tests/unit/model_server/test_tgi_server.py b/sagemaker-serve/tests/unit/model_server/test_tgi_server.py
index 12f84a747b..244ae6462c 100644
--- a/sagemaker-serve/tests/unit/model_server/test_tgi_server.py
+++ b/sagemaker-serve/tests/unit/model_server/test_tgi_server.py
@@ -8,101 +8,99 @@
 class TestLocalTgiServing(unittest.TestCase):
     """Test LocalTgiServing class."""

-    @patch('sagemaker.serve.model_server.tgi.server.Path')
-    @patch('sagemaker.serve.model_server.tgi.server.DeviceRequest')
+    @patch("sagemaker.serve.model_server.tgi.server.Path")
+    @patch("sagemaker.serve.model_server.tgi.server.DeviceRequest")
     def test_start_tgi_serving_jumpstart(self, mock_device_req, mock_path):
         """Test _start_tgi_serving with jumpstart=True."""
         from sagemaker.serve.model_server.tgi.server import LocalTgiServing
-
+
         server = LocalTgiServing()
         mock_client = Mock()
         mock_container = Mock()
         mock_client.containers.run.return_value = mock_container
-
+
         mock_path_obj = Mock()
         mock_path.return_value.joinpath.return_value = mock_path_obj
         mock_device_req.return_value = Mock()
-
+
         server._start_tgi_serving(
             client=mock_client,
             image="test-image:latest",
             model_path="/path/to/model",
             secret_key="test-secret",
             env_vars={"CUSTOM_VAR": "value"},
-            jumpstart=True
+            jumpstart=True,
         )
-
+
         self.assertEqual(server.container, mock_container)
         mock_client.containers.run.assert_called_once()
         call_args = mock_client.containers.run.call_args
         # Check that the command includes --model-id
         self.assertEqual(call_args[0][1][0], "--model-id")

-    @patch('sagemaker.serve.model_server.tgi.server._update_env_vars')
-    @patch('sagemaker.serve.model_server.tgi.server.Path')
-    @patch('sagemaker.serve.model_server.tgi.server.DeviceRequest')
+    @patch("sagemaker.serve.model_server.tgi.server._update_env_vars")
+    @patch("sagemaker.serve.model_server.tgi.server.Path")
+    @patch("sagemaker.serve.model_server.tgi.server.DeviceRequest")
     def test_start_tgi_serving_non_jumpstart(self, mock_device_req, mock_path, mock_update_env):
         """Test _start_tgi_serving with jumpstart=False."""
         from sagemaker.serve.model_server.tgi.server import LocalTgiServing
-
+
         server = LocalTgiServing()
         mock_client = Mock()
         mock_container = Mock()
         mock_client.containers.run.return_value = mock_container
-
+
         mock_path_obj = Mock()
         mock_path.return_value.joinpath.return_value = mock_path_obj
         mock_device_req.return_value = Mock()
         mock_update_env.return_value = {"HF_HOME": "/opt/ml/model/"}
-
+
         server._start_tgi_serving(
             client=mock_client,
             image="test-image:latest",
             model_path="/path/to/model",
             secret_key="test-secret",
             env_vars={"CUSTOM_VAR": "value"},
-            jumpstart=False
+            jumpstart=False,
         )
-
+
         self.assertEqual(server.container, mock_container)
         mock_update_env.assert_called_once()

-    @patch('sagemaker.serve.model_server.tgi.server.requests.post')
-    @patch('sagemaker.serve.model_server.tgi.server.get_docker_host')
+    @patch("sagemaker.serve.model_server.tgi.server.requests.post")
+    @patch("sagemaker.serve.model_server.tgi.server.get_docker_host")
     def test_invoke_tgi_serving_success(self, mock_get_host, mock_post):
         """Test _invoke_tgi_serving successful request."""
         from sagemaker.serve.model_server.tgi.server import LocalTgiServing
-
+
         server = LocalTgiServing()
         mock_get_host.return_value = "localhost"
         mock_response = Mock()
         mock_response.content = b'{"generated_text": "result"}'
         mock_post.return_value = mock_response
-
+
         result = server._invoke_tgi_serving(
-            request='{"inputs": "test"}',
-            content_type="application/json",
-            accept="application/json"
+            request='{"inputs": "test"}', content_type="application/json", accept="application/json"
         )
-
+
         self.assertEqual(result, b'{"generated_text": "result"}')
         mock_post.assert_called_once()

-    @patch('sagemaker.serve.model_server.tgi.server.requests.post')
-    @patch('sagemaker.serve.model_server.tgi.server.get_docker_host')
+    @patch("sagemaker.serve.model_server.tgi.server.requests.post")
+    @patch("sagemaker.serve.model_server.tgi.server.get_docker_host")
     def test_invoke_tgi_serving_failure(self, mock_get_host, mock_post):
         """Test _invoke_tgi_serving handles errors."""
         from sagemaker.serve.model_server.tgi.server import LocalTgiServing
-
+
         server = LocalTgiServing()
         mock_get_host.return_value = "localhost"
         mock_post.side_effect = Exception("Connection error")
-
+
         with self.assertRaises(Exception) as context:
             server._invoke_tgi_serving(
                 request='{"inputs": "test"}',
                 content_type="application/json",
-                accept="application/json"
+                accept="application/json",
             )
         self.assertIn("Unable to send request", str(context.exception))

@@ -110,70 +108,78 @@ def test_invoke_tgi_serving_failure(self, mock_get_host, mock_post):
 class TestSageMakerTgiServing(unittest.TestCase):
     """Test SageMakerTgiServing class."""

-    @patch('sagemaker.serve.model_server.tgi.server._is_s3_uri')
+    @patch("sagemaker.serve.model_server.tgi.server._is_s3_uri")
     def test_upload_tgi_artifacts_with_s3_path(self, mock_is_s3):
         """Test _upload_tgi_artifacts with S3 path."""
         from sagemaker.serve.model_server.tgi.server import SageMakerTgiServing
-
+
         server = SageMakerTgiServing()
         mock_is_s3.return_value = True
         mock_session = Mock()
-
+
         model_data, env_vars = server._upload_tgi_artifacts(
             model_path="s3://bucket/model",
             sagemaker_session=mock_session,
             jumpstart=False,
-            should_upload_artifacts=False
+            should_upload_artifacts=False,
         )
-
+
         self.assertIsNotNone(model_data)
         self.assertEqual(model_data["S3DataSource"]["S3Uri"], "s3://bucket/model/")

-    @patch('sagemaker.serve.model_server.tgi.server._is_s3_uri')
+    @patch("sagemaker.serve.model_server.tgi.server._is_s3_uri")
     def test_upload_tgi_artifacts_jumpstart(self, mock_is_s3):
         """Test _upload_tgi_artifacts with jumpstart=True."""
         from sagemaker.serve.model_server.tgi.server import SageMakerTgiServing
-
+
         server = SageMakerTgiServing()
         mock_is_s3.return_value = True
         mock_session = Mock()
-
+
         model_data, env_vars = server._upload_tgi_artifacts(
             model_path="s3://bucket/model",
             sagemaker_session=mock_session,
             jumpstart=True,
-            should_upload_artifacts=False
+            should_upload_artifacts=False,
         )
-
+
         self.assertIsNotNone(model_data)
         self.assertEqual(env_vars, {})

-    @patch('sagemaker.serve.model_server.tgi.server.S3Uploader')
-    @patch('sagemaker.serve.model_server.tgi.server.s3_path_join')
-    @patch('sagemaker.serve.model_server.tgi.server.determine_bucket_and_prefix')
-    @patch('sagemaker.serve.model_server.tgi.server.parse_s3_url')
-    @patch('sagemaker.serve.model_server.tgi.server.fw_utils')
-    @patch('sagemaker.serve.model_server.tgi.server._is_s3_uri')
-    @patch('sagemaker.serve.model_server.tgi.server.Path')
-    def test_upload_tgi_artifacts_uploads_to_s3(self, mock_path, mock_is_s3, mock_fw_utils,
-                                                mock_parse, mock_determine, mock_s3_join, mock_uploader):
+    @patch("sagemaker.serve.model_server.tgi.server.S3Uploader")
+    @patch("sagemaker.serve.model_server.tgi.server.s3_path_join")
+    @patch("sagemaker.serve.model_server.tgi.server.determine_bucket_and_prefix")
+    @patch("sagemaker.serve.model_server.tgi.server.parse_s3_url")
+    @patch("sagemaker.serve.model_server.tgi.server.fw_utils")
+    @patch("sagemaker.serve.model_server.tgi.server._is_s3_uri")
+    @patch("sagemaker.serve.model_server.tgi.server.Path")
+    def test_upload_tgi_artifacts_uploads_to_s3(
+        self,
+        mock_path,
+        mock_is_s3,
+        mock_fw_utils,
+        mock_parse,
+        mock_determine,
+        mock_s3_join,
+        mock_uploader,
+    ):
         """Test _upload_tgi_artifacts uploads to S3."""
         from sagemaker.serve.model_server.tgi.server import SageMakerTgiServing
-
+
         server = SageMakerTgiServing()
         mock_is_s3.return_value = False
         mock_parse.return_value = ("bucket", "prefix")
         mock_determine.return_value = ("bucket", "code_prefix")
         mock_s3_join.return_value = "s3://bucket/code_prefix/code"
         mock_uploader.upload.return_value = "s3://bucket/code_prefix/code"
-
+
         mock_path_obj = Mock()
         mock_code_dir = Mock()
         mock_path_obj.joinpath.return_value = mock_code_dir
         mock_path.return_value = mock_path_obj
-
+
         mock_session = Mock()
-
+
         model_data, env_vars = server._upload_tgi_artifacts(
             model_path="/local/model",
             sagemaker_session=mock_session,
@@ -181,9 +187,9 @@ def test_upload_tgi_artifacts_uploads_to_s3(self, mock_path, mock_is_s3, mock_fw
             s3_model_data_url="s3://bucket/prefix",
             image="test-image",
             env_vars={"CUSTOM": "var"},
-            should_upload_artifacts=True
+            should_upload_artifacts=True,
         )
-
+
         self.assertIsNotNone(model_data)
         mock_uploader.upload.assert_called_once()

@@ -194,7 +200,7 @@ class TestUpdateEnvVars(unittest.TestCase):
     def test_update_env_vars_with_none(self):
         """Test _update_env_vars with None input."""
         from sagemaker.serve.model_server.tgi.server import _update_env_vars
-
+
         result = _update_env_vars(None)
         self.assertIn("HF_HOME", result)
         self.assertIn("HUGGINGFACE_HUB_CACHE", result)

@@ -202,10 +208,10 @@ def test_update_env_vars_with_none(self):
     def test_update_env_vars_with_custom_vars(self):
         """Test _update_env_vars with custom variables."""
         from sagemaker.serve.model_server.tgi.server import _update_env_vars
-
+
         custom_vars = {"CUSTOM_KEY": "custom_value"}
         result = _update_env_vars(custom_vars)
-
+
         self.assertIn("CUSTOM_KEY", result)
         self.assertIn("HF_HOME", result)
         self.assertEqual(result["CUSTOM_KEY"], "custom_value")
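# Sketch of the merge behavior TestUpdateEnvVars exercises (assumption: the
# default values shown here are illustrative; the tests only assert the key
# names and that caller-supplied variables win the merge):
def _update_env_vars_sketch(env_vars=None):
    defaults = {
        "HF_HOME": "/opt/ml/model/",  # matches the mocked return value above
        "HUGGINGFACE_HUB_CACHE": "/opt/ml/model/",  # cache path assumed
    }
    defaults.update(env_vars or {})  # custom variables override the defaults
    return defaults


assert "HF_HOME" in _update_env_vars_sketch(None)
assert _update_env_vars_sketch({"CUSTOM_KEY": "custom_value"})["CUSTOM_KEY"] == "custom_value"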
"model-id", - {"num_attention_heads": 12}, - mock_schema_builder + "model-id", {"num_attention_heads": 12}, mock_schema_builder ) - + self.assertEqual(env["SHARDED"], "false") self.assertEqual(env["NUM_SHARD"], "1") self.assertEqual(env["DTYPE"], "bfloat16") self.assertEqual(max_new_tokens, 256) - @patch('sagemaker.serve.model_server.tgi.utils._get_default_max_tokens') - @patch('sagemaker.serve.model_server.tgi.utils._get_default_tensor_parallel_degree') + @patch("sagemaker.serve.model_server.tgi.utils._get_default_max_tokens") + @patch("sagemaker.serve.model_server.tgi.utils._get_default_tensor_parallel_degree") def test_get_default_tgi_configurations_no_parallel_degree(self, mock_parallel, mock_tokens): """Test TGI configurations when parallel degree is None.""" from sagemaker.serve.model_server.tgi.utils import _get_default_tgi_configurations - + mock_parallel.return_value = None mock_tokens.return_value = (1024, 256) - + mock_schema_builder = Mock() mock_schema_builder.sample_input = {"inputs": "test"} mock_schema_builder.sample_output = [{"generated_text": "output"}] - - env, max_new_tokens = _get_default_tgi_configurations( - "model-id", - {}, - mock_schema_builder - ) - + + env, max_new_tokens = _get_default_tgi_configurations("model-id", {}, mock_schema_builder) + self.assertIsNone(env["SHARDED"]) self.assertIsNone(env["NUM_SHARD"]) self.assertEqual(env["DTYPE"], "bfloat16") self.assertEqual(max_new_tokens, 256) - @patch('sagemaker.serve.model_server.tgi.utils._get_default_max_tokens') - @patch('sagemaker.serve.model_server.tgi.utils._get_default_tensor_parallel_degree') + @patch("sagemaker.serve.model_server.tgi.utils._get_default_max_tokens") + @patch("sagemaker.serve.model_server.tgi.utils._get_default_tensor_parallel_degree") def test_get_default_tgi_configurations_returns_tuple(self, mock_parallel, mock_tokens): """Test that function returns a tuple.""" from sagemaker.serve.model_server.tgi.utils import _get_default_tgi_configurations - + mock_parallel.return_value = 2 mock_tokens.return_value = (1024, 256) - + mock_schema_builder = Mock() mock_schema_builder.sample_input = {"inputs": "test"} mock_schema_builder.sample_output = [{"generated_text": "output"}] - + result = _get_default_tgi_configurations( - "model-id", - {"num_attention_heads": 16}, - mock_schema_builder + "model-id", {"num_attention_heads": 16}, mock_schema_builder ) - + self.assertIsInstance(result, tuple) self.assertEqual(len(result), 2) self.assertIsInstance(result[0], dict) diff --git a/sagemaker-serve/tests/unit/model_server/test_torchserve_inference.py b/sagemaker-serve/tests/unit/model_server/test_torchserve_inference.py index 9d5fc57485..e3081f8b94 100644 --- a/sagemaker-serve/tests/unit/model_server/test_torchserve_inference.py +++ b/sagemaker-serve/tests/unit/model_server/test_torchserve_inference.py @@ -14,93 +14,111 @@ class TestTorchServeInference(unittest.TestCase): def test_predict_fn_logic(self): """Test predict_fn logic.""" + def predict_fn(input_data, predict_callable): return predict_callable(input_data) - + mock_predict_callable = Mock(return_value=[0.1, 0.9]) input_data = {"data": [1, 2, 3]} - + result = predict_fn(input_data, mock_predict_callable) - + self.assertEqual(result, [0.1, 0.9]) mock_predict_callable.assert_called_once_with(input_data) def test_input_fn_with_preprocess_logic(self): """Test input_fn with preprocess logic.""" + def input_fn(input_data, content_type, schema_builder, inference_spec): # Deserialize if hasattr(schema_builder, "custom_input_translator"): 
diff --git a/sagemaker-serve/tests/unit/model_server/test_torchserve_inference.py b/sagemaker-serve/tests/unit/model_server/test_torchserve_inference.py
index 9d5fc57485..e3081f8b94 100644
--- a/sagemaker-serve/tests/unit/model_server/test_torchserve_inference.py
+++ b/sagemaker-serve/tests/unit/model_server/test_torchserve_inference.py
@@ -14,93 +14,111 @@ class TestTorchServeInference(unittest.TestCase):
     def test_predict_fn_logic(self):
         """Test predict_fn logic."""
+
         def predict_fn(input_data, predict_callable):
             return predict_callable(input_data)
-
+
         mock_predict_callable = Mock(return_value=[0.1, 0.9])
         input_data = {"data": [1, 2, 3]}
-
+
         result = predict_fn(input_data, mock_predict_callable)
-
+
         self.assertEqual(result, [0.1, 0.9])
         mock_predict_callable.assert_called_once_with(input_data)

     def test_input_fn_with_preprocess_logic(self):
         """Test input_fn with preprocess logic."""
+
         def input_fn(input_data, content_type, schema_builder, inference_spec):
             # Deserialize
             if hasattr(schema_builder, "custom_input_translator"):
                 deserialized_data = schema_builder.custom_input_translator.deserialize(
-                    io.BytesIO(input_data.encode("utf-8")) if isinstance(input_data, str) else io.BytesIO(input_data),
+                    (
+                        io.BytesIO(input_data.encode("utf-8"))
+                        if isinstance(input_data, str)
+                        else io.BytesIO(input_data)
+                    ),
                     content_type,
                 )
             else:
                 deserialized_data = schema_builder.input_deserializer.deserialize(
-                    io.BytesIO(input_data.encode("utf-8")) if isinstance(input_data, str) else io.BytesIO(input_data),
+                    (
+                        io.BytesIO(input_data.encode("utf-8"))
+                        if isinstance(input_data, str)
+                        else io.BytesIO(input_data)
+                    ),
                     content_type,
                 )
-
+
             # Preprocess if available
             if hasattr(inference_spec, "preprocess"):
                 preprocessed = inference_spec.preprocess(deserialized_data)
                 if preprocessed is not None:
                     return preprocessed
-
+
             return deserialized_data
-
+
         schema_builder = Mock()
         schema_builder.custom_input_translator = Mock()
         schema_builder.custom_input_translator.deserialize = Mock(return_value={"data": [1, 2, 3]})
-
+
         inference_spec = Mock()
         inference_spec.preprocess = Mock(return_value={"preprocessed": True})
-
+
         result = input_fn('{"data": [1, 2, 3]}', "application/json", schema_builder, inference_spec)
-
+
         self.assertEqual(result, {"preprocessed": True})
         inference_spec.preprocess.assert_called_once_with({"data": [1, 2, 3]})

     def test_output_fn_with_postprocess_logic(self):
         """Test output_fn with postprocess logic."""
+
         def output_fn(predictions, accept_type, schema_builder, inference_spec):
             # Postprocess if available
             if hasattr(inference_spec, "postprocess"):
                 postprocessed = inference_spec.postprocess(predictions)
                 if postprocessed is not None:
                     predictions = postprocessed
-
+
             # Serialize
             if hasattr(schema_builder, "custom_output_translator"):
                 return schema_builder.custom_output_translator.serialize(predictions, accept_type)
             else:
                 return schema_builder.output_serializer.serialize(predictions)
-
+
         schema_builder = Mock()
         schema_builder.custom_output_translator = Mock()
-        schema_builder.custom_output_translator.serialize = Mock(return_value=b'{"predictions": [0.1, 0.9]}')
-
+        schema_builder.custom_output_translator.serialize = Mock(
+            return_value=b'{"predictions": [0.1, 0.9]}'
+        )
+
         inference_spec = Mock()
         inference_spec.postprocess = Mock(return_value={"postprocessed": True})
-
+
         result = output_fn([0.1, 0.9], "application/json", schema_builder, inference_spec)
-
+
         inference_spec.postprocess.assert_called_once_with([0.1, 0.9])
-        schema_builder.custom_output_translator.serialize.assert_called_once_with({"postprocessed": True}, "application/json")
+        schema_builder.custom_output_translator.serialize.assert_called_once_with(
+            {"postprocessed": True}, "application/json"
+        )

-    @patch.dict(os.environ, {'MLFLOW_MODEL_FLAVOR': 'pytorch'})
+    @patch.dict(os.environ, {"MLFLOW_MODEL_FLAVOR": "pytorch"})
     def test_get_mlflow_flavor_logic(self):
         """Test _get_mlflow_flavor logic."""
+
         def _get_mlflow_flavor():
             return os.getenv("MLFLOW_MODEL_FLAVOR")
-
+
         result = _get_mlflow_flavor()
-        self.assertEqual(result, 'pytorch')
+        self.assertEqual(result, "pytorch")

-    @patch('importlib.import_module')
+    @patch("importlib.import_module")
     def test_load_mlflow_model_logic(self, mock_import):
         """Test _load_mlflow_model logic."""
+
         def _load_mlflow_model(deployment_flavor, model_dir):
             import importlib
+
             flavor_loader_map = {
                 "pytorch": ("mlflow.pytorch", "load_model"),
                 "tensorflow": ("mlflow.tensorflow", "load_model"),
@@ -111,14 +129,14 @@ def _load_mlflow_model(deployment_flavor, model_dir):
             flavor_module = importlib.import_module(flavor_module_name)
             load_model_function = getattr(flavor_module, load_function_name)
             return load_model_function(model_dir)
-
+
         mock_module = Mock()
         mock_module.load_model = Mock(return_value=Mock())
         mock_import.return_value = mock_module
-
-        result = _load_mlflow_model('tensorflow', '/model/dir')
-
-        mock_import.assert_called_once_with('mlflow.tensorflow')
+
+        result = _load_mlflow_model("tensorflow", "/model/dir")
+
+        mock_import.assert_called_once_with("mlflow.tensorflow")


 if __name__ == "__main__":
diff --git a/sagemaker-serve/tests/unit/model_server/test_torchserve_prepare.py b/sagemaker-serve/tests/unit/model_server/test_torchserve_prepare.py
index 1ae35eca6a..d1ca6decde 100644
--- a/sagemaker-serve/tests/unit/model_server/test_torchserve_prepare.py
+++ b/sagemaker-serve/tests/unit/model_server/test_torchserve_prepare.py
@@ -17,170 +17,162 @@ def tearDown(self):
         if Path(self.temp_dir).exists():
             shutil.rmtree(self.temp_dir)

-    @patch('sagemaker.serve.model_server.torchserve.prepare.compute_hash')
-    @patch('sagemaker.serve.model_server.torchserve.prepare.generate_secret_key')
-    @patch('sagemaker.serve.model_server.torchserve.prepare.capture_dependencies')
-    @patch('sagemaker.serve.model_server.torchserve.prepare.is_1p_image_uri')
-    @patch('shutil.copy2')
-    def test_prepare_for_torchserve_standard_image(self, mock_copy, mock_is_1p, mock_capture,
-                                                   mock_gen_key, mock_hash):
+    @patch("sagemaker.serve.model_server.torchserve.prepare.compute_hash")
+    @patch("sagemaker.serve.model_server.torchserve.prepare.capture_dependencies")
+    @patch("sagemaker.serve.model_server.torchserve.prepare.is_1p_image_uri")
+    @patch("shutil.copy2")
+    def test_prepare_for_torchserve_standard_image(
+        self, mock_copy, mock_is_1p, mock_capture, mock_hash
+    ):
         """Test prepare_for_torchserve with standard image."""
         from sagemaker.serve.model_server.torchserve.prepare import prepare_for_torchserve
-
+
         model_path = Path(self.temp_dir) / "model"
         code_dir = model_path / "code"
         code_dir.mkdir(parents=True)
-
+
         serve_pkl = code_dir / "serve.pkl"
         serve_pkl.write_bytes(b"test data")
-
+
         mock_is_1p.return_value = True
-        mock_gen_key.return_value = "test-secret-key"
         mock_hash.return_value = "test-hash"
         mock_session = Mock()
         mock_inference_spec = Mock()
-
-        with patch('builtins.open', mock_open(read_data=b"test data")):
+
+        with patch("builtins.open", mock_open(read_data=b"test data")):
             secret_key = prepare_for_torchserve(
                 model_path=str(model_path),
                 shared_libs=[],
                 dependencies={},
                 session=mock_session,
                 image_uri="test-pytorch-image",
-                inference_spec=mock_inference_spec
+                inference_spec=mock_inference_spec,
             )
-
-        self.assertEqual(secret_key, "test-secret-key")
+
         mock_inference_spec.prepare.assert_called_once_with(str(model_path))
         mock_capture.assert_called_once()

-    @patch('os.rename')
-    @patch('sagemaker.serve.model_server.torchserve.prepare.compute_hash')
-    @patch('sagemaker.serve.model_server.torchserve.prepare.generate_secret_key')
-    @patch('sagemaker.serve.model_server.torchserve.prepare.capture_dependencies')
-    @patch('sagemaker.serve.model_server.torchserve.prepare.is_1p_image_uri')
-    @patch('shutil.copy2')
-    def test_prepare_for_torchserve_xgboost_image(self, mock_copy, mock_is_1p, mock_capture,
-                                                  mock_gen_key, mock_hash, mock_rename):
+    @patch("os.rename")
+    @patch("sagemaker.serve.model_server.torchserve.prepare.compute_hash")
+    @patch("sagemaker.serve.model_server.torchserve.prepare.capture_dependencies")
+    @patch("sagemaker.serve.model_server.torchserve.prepare.is_1p_image_uri")
+    @patch("shutil.copy2")
+    def test_prepare_for_torchserve_xgboost_image(
+        self, mock_copy, mock_is_1p, mock_capture, mock_hash, mock_rename
+    ):
         """Test prepare_for_torchserve with xgboost image."""
         from sagemaker.serve.model_server.torchserve.prepare import prepare_for_torchserve
-
+
         model_path = Path(self.temp_dir) / "model"
         code_dir = model_path / "code"
         code_dir.mkdir(parents=True)
-
+
         serve_pkl = code_dir / "serve.pkl"
         serve_pkl.write_bytes(b"test data")
-
+
         mock_is_1p.return_value = True
-        mock_gen_key.return_value = "test-secret-key"
         mock_hash.return_value = "test-hash"
         mock_session = Mock()
-
-        with patch('builtins.open', mock_open(read_data=b"test data")):
+
+        with patch("builtins.open", mock_open(read_data=b"test data")):
             secret_key = prepare_for_torchserve(
                 model_path=str(model_path),
                 shared_libs=[],
                 dependencies={},
                 session=mock_session,
                 image_uri="xgboost-image:latest",
-                inference_spec=None
+                inference_spec=None,
             )
-
-        self.assertEqual(secret_key, "test-secret-key")
+
         # Verify xgboost_inference.py was copied and renamed
         mock_rename.assert_called_once()

-    @patch('sagemaker.serve.model_server.torchserve.prepare.compute_hash')
-    @patch('sagemaker.serve.model_server.torchserve.prepare.generate_secret_key')
-    @patch('sagemaker.serve.model_server.torchserve.prepare.capture_dependencies')
-    @patch('sagemaker.serve.model_server.torchserve.prepare.is_1p_image_uri')
-    @patch('shutil.copy2')
-    def test_prepare_for_torchserve_with_shared_libs(self, mock_copy, mock_is_1p, mock_capture,
-                                                     mock_gen_key, mock_hash):
+    @patch("sagemaker.serve.model_server.torchserve.prepare.compute_hash")
+    @patch("sagemaker.serve.model_server.torchserve.prepare.capture_dependencies")
+    @patch("sagemaker.serve.model_server.torchserve.prepare.is_1p_image_uri")
+    @patch("shutil.copy2")
+    def test_prepare_for_torchserve_with_shared_libs(
+        self, mock_copy, mock_is_1p, mock_capture, mock_hash
+    ):
         """Test prepare_for_torchserve copies shared libraries."""
         from sagemaker.serve.model_server.torchserve.prepare import prepare_for_torchserve
-
+
         model_path = Path(self.temp_dir) / "model"
         code_dir = model_path / "code"
         code_dir.mkdir(parents=True)
-
+
         serve_pkl = code_dir / "serve.pkl"
         serve_pkl.write_bytes(b"test data")
-
+
         shared_lib = Path(self.temp_dir) / "lib.so"
         shared_lib.touch()
-
+
         mock_is_1p.return_value = False
-        mock_gen_key.return_value = "test-key"
         mock_hash.return_value = "test-hash"
         mock_session = Mock()
-
-        with patch('builtins.open', mock_open(read_data=b"test data")):
+
+        with patch("builtins.open", mock_open(read_data=b"test data")):
             prepare_for_torchserve(
                 model_path=str(model_path),
                 shared_libs=[str(shared_lib)],
                 dependencies={},
                 session=mock_session,
-                image_uri="test-image"
+                image_uri="test-image",
             )
-
+
         # Verify copy2 was called for shared lib
         self.assertTrue(any(str(shared_lib) in str(call) for call in mock_copy.call_args_list))

-    @patch('sagemaker.serve.model_server.torchserve.prepare.is_1p_image_uri')
+    @patch("sagemaker.serve.model_server.torchserve.prepare.is_1p_image_uri")
     def test_prepare_for_torchserve_invalid_dir(self, mock_is_1p):
         """Test prepare_for_torchserve raises exception for invalid directory."""
         from sagemaker.serve.model_server.torchserve.prepare import prepare_for_torchserve
-
+
         file_path = Path(self.temp_dir) / "file.txt"
         file_path.touch()
-
+
         mock_session = Mock()
-
+
         with self.assertRaises(Exception) as context:
             prepare_for_torchserve(
                 model_path=str(file_path),
                 shared_libs=[],
                 dependencies={},
                 session=mock_session,
-                image_uri="test-image"
+                image_uri="test-image",
             )
         self.assertIn("not a valid directory", str(context.exception))

-    @patch('sagemaker.serve.model_server.torchserve.prepare.compute_hash')
-    @patch('sagemaker.serve.model_server.torchserve.prepare.generate_secret_key')
-    @patch('sagemaker.serve.model_server.torchserve.prepare.capture_dependencies')
-    @patch('sagemaker.serve.model_server.torchserve.prepare.is_1p_image_uri')
-    @patch('shutil.copy2')
-    def test_prepare_for_torchserve_no_inference_spec(self, mock_copy, mock_is_1p, mock_capture,
-                                                      mock_gen_key, mock_hash):
+    @patch("sagemaker.serve.model_server.torchserve.prepare.compute_hash")
+    @patch("sagemaker.serve.model_server.torchserve.prepare.capture_dependencies")
+    @patch("sagemaker.serve.model_server.torchserve.prepare.is_1p_image_uri")
+    @patch("shutil.copy2")
+    def test_prepare_for_torchserve_no_inference_spec(
+        self, mock_copy, mock_is_1p, mock_capture, mock_hash
+    ):
         """Test prepare_for_torchserve without inference_spec."""
         from sagemaker.serve.model_server.torchserve.prepare import prepare_for_torchserve
-
+
         model_path = Path(self.temp_dir) / "model"
         code_dir = model_path / "code"
         code_dir.mkdir(parents=True)
-
+
         serve_pkl = code_dir / "serve.pkl"
         serve_pkl.write_bytes(b"test data")
-
+
         mock_is_1p.return_value = False
-        mock_gen_key.return_value = "test-key"
         mock_hash.return_value = "test-hash"
         mock_session = Mock()
-
-        with patch('builtins.open', mock_open(read_data=b"test data")):
+
+        with patch("builtins.open", mock_open(read_data=b"test data")):
             secret_key = prepare_for_torchserve(
                 model_path=str(model_path),
                 shared_libs=[],
                 dependencies={},
                 session=mock_session,
                 image_uri="test-image",
-                inference_spec=None
+                inference_spec=None,
             )
-
-        self.assertEqual(secret_key, "test-key")


 if __name__ == "__main__":
diff --git a/sagemaker-serve/tests/unit/model_server/test_torchserve_server.py b/sagemaker-serve/tests/unit/model_server/test_torchserve_server.py
index 95b0645076..ccc4368841 100644
--- a/sagemaker-serve/tests/unit/model_server/test_torchserve_server.py
+++ b/sagemaker-serve/tests/unit/model_server/test_torchserve_server.py
@@ -8,70 +8,68 @@
 class TestLocalTorchServe(unittest.TestCase):
     """Test LocalTorchServe class."""

-    @patch('sagemaker.serve.model_server.torchserve.server.Path')
+    @patch("sagemaker.serve.model_server.torchserve.server.Path")
     def test_start_torch_serve(self, mock_path):
         """Test _start_torch_serve creates container."""
         from sagemaker.serve.model_server.torchserve.server import LocalTorchServe
-
+
         server = LocalTorchServe()
         mock_client = Mock()
         mock_container = Mock()
         mock_client.containers.run.return_value = mock_container
-
+
         mock_path_obj = Mock()
         mock_path.return_value = mock_path_obj
-
+
         server._start_torch_serve(
             client=mock_client,
             image="torchserve:latest",
             model_path="/path/to/model",
             secret_key="test-secret",
-            env_vars={"CUSTOM_VAR": "value"}
+            env_vars={"CUSTOM_VAR": "value"},
         )
-
+
         self.assertEqual(server.container, mock_container)
         mock_client.containers.run.assert_called_once()
         call_kwargs = mock_client.containers.run.call_args[1]
-        self.assertIn("SAGEMAKER_SERVE_SECRET_KEY", call_kwargs["environment"])
-        self.assertEqual(call_kwargs["environment"]["SAGEMAKER_SERVE_SECRET_KEY"], "test-secret")
         self.assertEqual(call_kwargs["environment"]["CUSTOM_VAR"], "value")

-    @patch('sagemaker.serve.model_server.torchserve.server.requests.post')
-    @patch('sagemaker.serve.model_server.torchserve.server.get_docker_host')
+    @patch("sagemaker.serve.model_server.torchserve.server.requests.post")
+    @patch("sagemaker.serve.model_server.torchserve.server.get_docker_host")
     def test_invoke_torch_serve_success(self, mock_get_host, mock_post):
         """Test _invoke_torch_serve successful request."""
         from sagemaker.serve.model_server.torchserve.server import LocalTorchServe
-
+
         server = LocalTorchServe()
         mock_get_host.return_value = "localhost"
         mock_response = Mock()
         mock_response.content = b'{"predictions": [0.1, 0.9]}'
         mock_post.return_value = mock_response
-
+
         result = server._invoke_torch_serve(
             request='{"data": [1, 2, 3]}',
             content_type="application/json",
-            accept="application/json"
+            accept="application/json",
         )
-
+
         self.assertEqual(result, b'{"predictions": [0.1, 0.9]}')
         mock_post.assert_called_once()

-    @patch('sagemaker.serve.model_server.torchserve.server.requests.post')
-    @patch('sagemaker.serve.model_server.torchserve.server.get_docker_host')
+    @patch("sagemaker.serve.model_server.torchserve.server.requests.post")
+    @patch("sagemaker.serve.model_server.torchserve.server.get_docker_host")
     def test_invoke_torch_serve_failure(self, mock_get_host, mock_post):
         """Test _invoke_torch_serve handles errors."""
         from sagemaker.serve.model_server.torchserve.server import LocalTorchServe
-
+
         server = LocalTorchServe()
         mock_get_host.return_value = "localhost"
         mock_post.side_effect = Exception("Connection error")
-
+
         with self.assertRaises(Exception) as context:
             server._invoke_torch_serve(
                 request='{"data": [1, 2, 3]}',
                 content_type="application/json",
-                accept="application/json"
+                accept="application/json",
             )
         self.assertIn("Unable to send request", str(context.exception))

@@ -79,78 +77,78 @@ def test_invoke_torch_serve_failure(self, mock_get_host, mock_post):
 class TestSageMakerTorchServe(unittest.TestCase):
     """Test SageMakerTorchServe class."""

-    @patch('sagemaker.serve.model_server.torchserve.server._is_s3_uri')
+    @patch("sagemaker.serve.model_server.torchserve.server._is_s3_uri")
     def test_upload_torchserve_artifacts_with_s3_path(self, mock_is_s3):
         """Test _upload_torchserve_artifacts with S3 path."""
         from sagemaker.serve.model_server.torchserve.server import SageMakerTorchServe
-
+
         server = SageMakerTorchServe()
         mock_is_s3.return_value = True
         mock_session = Mock()
         mock_session.boto_region_name = "us-west-2"
-
+
         s3_path, env_vars = server._upload_torchserve_artifacts(
             model_path="s3://bucket/model",
             sagemaker_session=mock_session,
             secret_key="test-key",
-            should_upload_artifacts=False
+            should_upload_artifacts=False,
         )
-
+
         self.assertEqual(s3_path, "s3://bucket/model")
-        self.assertIn("SAGEMAKER_SERVE_SECRET_KEY", env_vars)
-        self.assertEqual(env_vars["SAGEMAKER_SERVE_SECRET_KEY"], "test-key")
-
-    @patch('sagemaker.serve.model_server.torchserve.server.upload')
-    @patch('sagemaker.serve.model_server.torchserve.server.determine_bucket_and_prefix')
-    @patch('sagemaker.serve.model_server.torchserve.server.parse_s3_url')
-    @patch('sagemaker.serve.model_server.torchserve.server.fw_utils')
-    @patch('sagemaker.serve.model_server.torchserve.server._is_s3_uri')
-    def test_upload_torchserve_artifacts_uploads_to_s3(self, mock_is_s3, mock_fw_utils,
-                                                       mock_parse, mock_determine, mock_upload):
+        self.assertIn("SAGEMAKER_SUBMIT_DIRECTORY", env_vars)
+
+    @patch("sagemaker.serve.model_server.torchserve.server.upload")
+    @patch("sagemaker.serve.model_server.torchserve.server.determine_bucket_and_prefix")
+    @patch("sagemaker.serve.model_server.torchserve.server.parse_s3_url")
+    @patch("sagemaker.serve.model_server.torchserve.server.fw_utils")
+    @patch("sagemaker.serve.model_server.torchserve.server._is_s3_uri")
+    def test_upload_torchserve_artifacts_uploads_to_s3(
+        self, mock_is_s3, mock_fw_utils, mock_parse, mock_determine, mock_upload
+    ):
         """Test _upload_torchserve_artifacts uploads to S3."""
         from sagemaker.serve.model_server.torchserve.server import SageMakerTorchServe
-
+
         server = SageMakerTorchServe()
         mock_is_s3.return_value = False
         mock_parse.return_value = ("bucket", "prefix")
         mock_determine.return_value = ("bucket", "code_prefix")
         mock_upload.return_value = "s3://bucket/code_prefix/model.tar.gz"
-
+
         mock_session = Mock()
         mock_session.boto_region_name = "us-west-2"
-
+
         s3_path, env_vars = server._upload_torchserve_artifacts(
             model_path="/local/model",
             sagemaker_session=mock_session,
             secret_key="test-key",
             s3_model_data_url="s3://bucket/prefix",
             image="test-image",
-            should_upload_artifacts=True
+            should_upload_artifacts=True,
         )
-
+
         self.assertEqual(s3_path, "s3://bucket/code_prefix/model.tar.gz")
-        self.assertIn("SAGEMAKER_SERVE_SECRET_KEY", env_vars)
+        self.assertIn("SAGEMAKER_SUBMIT_DIRECTORY", env_vars)
         mock_upload.assert_called_once()

-    @patch('sagemaker.serve.model_server.torchserve.server._is_s3_uri')
+    @patch("sagemaker.serve.model_server.torchserve.server._is_s3_uri")
     def test_upload_torchserve_artifacts_no_upload(self, mock_is_s3):
         """Test _upload_torchserve_artifacts without uploading."""
         from sagemaker.serve.model_server.torchserve.server import SageMakerTorchServe
-
+
         server = SageMakerTorchServe()
         mock_is_s3.return_value = False
         mock_session = Mock()
         mock_session.boto_region_name = "us-west-2"
-
+
         s3_path, env_vars = server._upload_torchserve_artifacts(
             model_path="/local/model",
             sagemaker_session=mock_session,
             secret_key="test-key",
-            should_upload_artifacts=False
+            should_upload_artifacts=False,
        )
-
+
         self.assertIsNone(s3_path)
-        self.assertIn("SAGEMAKER_SERVE_SECRET_KEY", env_vars)
+        self.assertIn("SAGEMAKER_SUBMIT_DIRECTORY", env_vars)


 if __name__ == "__main__":
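# Sketch of the new env-var contract asserted above (assumption: the directory
# value is invented; the tests only check that SAGEMAKER_SUBMIT_DIRECTORY has
# replaced SAGEMAKER_SERVE_SECRET_KEY in every branch):
def _upload_torchserve_artifacts_sketch(model_path, should_upload, uploaded_uri=None):
    env_vars = {"SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code"}  # value assumed
    if model_path.startswith("s3://"):
        return model_path, env_vars  # pre-staged artifacts pass through unchanged
    if should_upload:
        return uploaded_uri, env_vars  # e.g. "s3://bucket/code_prefix/model.tar.gz"
    return None, env_vars  # local path, no upload requested


assert _upload_torchserve_artifacts_sketch("s3://bucket/model", False)[0] == "s3://bucket/model"
assert _upload_torchserve_artifacts_sketch("/local/model", False)[0] is None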
logic.""" + # Simulate the _load_mlflow_model behavior def _load_mlflow_model(deployment_flavor, model_dir): import importlib + flavor_loader_map = { "sklearn": ("mlflow.sklearn", "load_model"), "pytorch": ("mlflow.pytorch", "load_model"), @@ -60,70 +65,81 @@ def _load_mlflow_model(deployment_flavor, model_dir): flavor_module = importlib.import_module(flavor_module_name) load_model_function = getattr(flavor_module, load_function_name) return load_model_function(model_dir) - + mock_module = Mock() mock_module.load_model = Mock(return_value=Mock()) mock_import.return_value = mock_module - - result = _load_mlflow_model('sklearn', '/model/dir') - - mock_import.assert_called_once_with('mlflow.sklearn') + + result = _load_mlflow_model("sklearn", "/model/dir") + + mock_import.assert_called_once_with("mlflow.sklearn") def test_input_fn_custom_translator_logic(self): """Test input_fn with custom translator logic.""" import io - + # Simulate input_fn behavior def input_fn(input_data, content_type, schema_builder): if hasattr(schema_builder, "custom_input_translator"): return schema_builder.custom_input_translator.deserialize( - io.BytesIO(input_data.encode("utf-8")) if isinstance(input_data, str) else io.BytesIO(input_data), + ( + io.BytesIO(input_data.encode("utf-8")) + if isinstance(input_data, str) + else io.BytesIO(input_data) + ), content_type, ) else: return schema_builder.input_deserializer.deserialize( - io.BytesIO(input_data.encode("utf-8")) if isinstance(input_data, str) else io.BytesIO(input_data), + ( + io.BytesIO(input_data.encode("utf-8")) + if isinstance(input_data, str) + else io.BytesIO(input_data) + ), content_type[0], ) - + schema_builder = Mock() schema_builder.custom_input_translator = Mock() schema_builder.custom_input_translator.deserialize = Mock(return_value={"data": [1, 2, 3]}) - + result = input_fn('{"data": [1, 2, 3]}', ["application/json"], schema_builder) - + self.assertEqual(result, {"data": [1, 2, 3]}) def test_output_fn_custom_translator_logic(self): """Test output_fn with custom translator logic.""" + # Simulate output_fn behavior def output_fn(predictions, accept_type, schema_builder): if hasattr(schema_builder, "custom_output_translator"): return schema_builder.custom_output_translator.serialize(predictions, accept_type) else: return schema_builder.output_serializer.serialize(predictions) - + schema_builder = Mock() schema_builder.custom_output_translator = Mock() - schema_builder.custom_output_translator.serialize = Mock(return_value=b'{"predictions": [0.1, 0.9]}') - + schema_builder.custom_output_translator.serialize = Mock( + return_value=b'{"predictions": [0.1, 0.9]}' + ) + result = output_fn([0.1, 0.9], "application/json", schema_builder) - + self.assertEqual(result, b'{"predictions": [0.1, 0.9]}') def test_python_version_check_logic(self): """Test Python version parity check logic.""" import platform - + # Simulate _py_vs_parity_check behavior def _py_vs_parity_check(local_py_vs): container_py_vs = platform.python_version() if not local_py_vs or container_py_vs.split(".")[1] != local_py_vs.split(".")[1]: return False # Would log warning return True - + # Test matching versions - result = _py_vs_parity_check('3.9.0') + result = _py_vs_parity_check("3.9.0") # Result depends on actual Python version, just verify it runs self.assertIsInstance(result, bool) diff --git a/sagemaker-serve/tests/unit/servers/test_model_builder_servers.py b/sagemaker-serve/tests/unit/servers/test_model_builder_servers.py index b15e77a0b0..4355474c3d 100644 --- 
a/sagemaker-serve/tests/unit/servers/test_model_builder_servers.py +++ b/sagemaker-serve/tests/unit/servers/test_model_builder_servers.py @@ -8,15 +8,16 @@ import unittest # Prevent JumpStart from loading region config during import -os.environ['SAGEMAKER_INTERNAL_SKIP_REGION_CONFIG'] = '1' +os.environ["SAGEMAKER_INTERNAL_SKIP_REGION_CONFIG"] = "1" from sagemaker.serve.utils.types import ModelServer from sagemaker.serve.mode.function_pointers import Mode from sagemaker.serve.model_builder_servers import _ModelBuilderServers + class MockModelBuilderServers(_ModelBuilderServers): """Mock class that inherits _ModelBuilderServers behavior.""" - + def __init__(self): self.model_server = ModelServer.TORCHSERVE self.model = None @@ -46,75 +47,75 @@ def __init__(self): self.framework_version = None self._is_mlflow_model = False self.config_name = None - + def _deploy_local_endpoint(self, **kwargs): return Mock() - + def _deploy_core_endpoint(self, *args, **kwargs): return Mock() - + def _save_model_inference_spec(self): pass - + def _is_jumpstart_model_id(self): return False - + def _auto_detect_image_uri(self): pass - + def _prepare_for_mode(self, should_upload_artifacts=False): return ("s3://bucket/model.tar.gz", None) - + def _create_model(self): return Mock() - + def _validate_tgi_serving_sample_data(self): pass - + def _validate_djl_serving_sample_data(self): pass - + def _validate_for_triton(self): pass - + def _auto_detect_image_for_triton(self): pass - + def _save_inference_spec(self): pass - + def _prepare_for_triton(self): pass - + def get_huggingface_model_metadata(self, model_id, token=None): return {} - + def _normalize_framework_to_enum(self, framework): return framework - + def _get_processing_unit(self): return "cpu" - + def _get_smd_image_uri(self, processing_unit): return "smd-image-uri" - + def _create_conda_env(self): pass class TestBuildForModelServer(unittest.TestCase): """Test _build_for_model_server method.""" - + def setUp(self): self.builder = MockModelBuilderServers() - + def test_unsupported_model_server(self): """Test error for unsupported model server.""" self.builder.model_server = "INVALID_SERVER" with self.assertRaises(ValueError) as ctx: self.builder._build_for_model_server() self.assertIn("not supported", str(ctx.exception)) - + def test_missing_required_parameters(self): """Test error when model, MLflow path, and inference_spec are all missing.""" self.builder.model = None @@ -123,8 +124,8 @@ def test_missing_required_parameters(self): with self.assertRaises(ValueError) as ctx: self.builder._build_for_model_server() self.assertIn("Missing required parameter", str(ctx.exception)) - - @patch.object(MockModelBuilderServers, '_build_for_torchserve') + + @patch.object(MockModelBuilderServers, "_build_for_torchserve") def test_route_to_torchserve(self, mock_build): """Test routing to TorchServe builder.""" self.builder.model_server = ModelServer.TORCHSERVE @@ -132,8 +133,8 @@ def test_route_to_torchserve(self, mock_build): mock_build.return_value = Mock() self.builder._build_for_model_server() mock_build.assert_called_once() - - @patch.object(MockModelBuilderServers, '_build_for_triton') + + @patch.object(MockModelBuilderServers, "_build_for_triton") def test_route_to_triton(self, mock_build): """Test routing to Triton builder.""" self.builder.model_server = ModelServer.TRITON @@ -141,8 +142,8 @@ def test_route_to_triton(self, mock_build): mock_build.return_value = Mock() self.builder._build_for_model_server() mock_build.assert_called_once() - - 
@patch.object(MockModelBuilderServers, '_build_for_tensorflow_serving') + + @patch.object(MockModelBuilderServers, "_build_for_tensorflow_serving") def test_route_to_tensorflow_serving(self, mock_build): """Test routing to TensorFlow Serving builder.""" self.builder.model_server = ModelServer.TENSORFLOW_SERVING @@ -150,8 +151,8 @@ def test_route_to_tensorflow_serving(self, mock_build): mock_build.return_value = Mock() self.builder._build_for_model_server() mock_build.assert_called_once() - - @patch.object(MockModelBuilderServers, '_build_for_djl') + + @patch.object(MockModelBuilderServers, "_build_for_djl") def test_route_to_djl(self, mock_build): """Test routing to DJL builder.""" self.builder.model_server = ModelServer.DJL_SERVING @@ -159,8 +160,8 @@ def test_route_to_djl(self, mock_build): mock_build.return_value = Mock() self.builder._build_for_model_server() mock_build.assert_called_once() - - @patch.object(MockModelBuilderServers, '_build_for_tei') + + @patch.object(MockModelBuilderServers, "_build_for_tei") def test_route_to_tei(self, mock_build): """Test routing to TEI builder.""" self.builder.model_server = ModelServer.TEI @@ -168,8 +169,8 @@ def test_route_to_tei(self, mock_build): mock_build.return_value = Mock() self.builder._build_for_model_server() mock_build.assert_called_once() - - @patch.object(MockModelBuilderServers, '_build_for_tgi') + + @patch.object(MockModelBuilderServers, "_build_for_tgi") def test_route_to_tgi(self, mock_build): """Test routing to TGI builder.""" self.builder.model_server = ModelServer.TGI @@ -177,8 +178,8 @@ def test_route_to_tgi(self, mock_build): mock_build.return_value = Mock() self.builder._build_for_model_server() mock_build.assert_called_once() - - @patch.object(MockModelBuilderServers, '_build_for_transformers') + + @patch.object(MockModelBuilderServers, "_build_for_transformers") def test_route_to_mms(self, mock_build): """Test routing to MMS builder.""" self.builder.model_server = ModelServer.MMS @@ -186,8 +187,8 @@ def test_route_to_mms(self, mock_build): mock_build.return_value = Mock() self.builder._build_for_model_server() mock_build.assert_called_once() - - @patch.object(MockModelBuilderServers, '_build_for_smd') + + @patch.object(MockModelBuilderServers, "_build_for_smd") def test_route_to_smd(self, mock_build): """Test routing to SMD builder.""" self.builder.model_server = ModelServer.SMD @@ -199,107 +200,124 @@ def test_route_to_smd(self, mock_build): class TestBuildForTorchServe(unittest.TestCase): """Test _build_for_torchserve method.""" - + def setUp(self): self.builder = MockModelBuilderServers() self.builder.model_server = ModelServer.TORCHSERVE - - @patch.object(MockModelBuilderServers, '_save_model_inference_spec') - @patch.object(MockModelBuilderServers, '_is_jumpstart_model_id') - @patch.object(MockModelBuilderServers, '_auto_detect_image_uri') - @patch.object(MockModelBuilderServers, '_prepare_for_mode') - @patch.object(MockModelBuilderServers, '_create_model') - def test_build_with_hf_model_id(self, mock_create, mock_prepare, mock_detect, mock_js, mock_save): + + @patch.object(MockModelBuilderServers, "_save_model_inference_spec") + @patch.object(MockModelBuilderServers, "_is_jumpstart_model_id") + @patch.object(MockModelBuilderServers, "_auto_detect_image_uri") + @patch.object(MockModelBuilderServers, "_prepare_for_mode") + @patch.object(MockModelBuilderServers, "_create_model") + def test_build_with_hf_model_id( + self, mock_create, mock_prepare, mock_detect, mock_js, mock_save + ): """Test building with HuggingFace 
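# The routing tests above cover one branch per ModelServer value. A sketch of
# the dispatch they imply (method names taken from the patched targets; the
# mapping itself is an assumption about _build_for_model_server's internals):
_BUILDER_DISPATCH_SKETCH = {
    "TORCHSERVE": "_build_for_torchserve",
    "TRITON": "_build_for_triton",
    "TENSORFLOW_SERVING": "_build_for_tensorflow_serving",
    "DJL_SERVING": "_build_for_djl",
    "TEI": "_build_for_tei",
    "TGI": "_build_for_tgi",
    "MMS": "_build_for_transformers",
    "SMD": "_build_for_smd",
}


def _build_for_model_server_sketch(builder):
    name = getattr(builder.model_server, "name", None)
    if name not in _BUILDER_DISPATCH_SKETCH:
        raise ValueError(f"{builder.model_server} is not supported")
    if not (builder.model or builder.inference_spec):  # an MLflow path also qualifies
        raise ValueError("Missing required parameter: provide a model or inference_spec")
    return getattr(builder, _BUILDER_DISPATCH_SKETCH[name])()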
model ID.""" mock_js.return_value = False mock_create.return_value = Mock() self.builder.mode = Mode.IN_PROCESS self.builder.model = "bert-base-uncased" self.builder.env_vars = {"HUGGING_FACE_HUB_TOKEN": "test-token"} - + result = self.builder._build_for_torchserve() - + self.assertEqual(self.builder.env_vars["HF_MODEL_ID"], "bert-base-uncased") self.assertEqual(self.builder.env_vars["HF_TOKEN"], "test-token") self.assertIsNone(self.builder.s3_upload_path) mock_save.assert_called_once() mock_create.assert_called_once() - - @patch('sagemaker.serve.model_builder_servers.prepare_for_torchserve') - @patch.object(MockModelBuilderServers, '_save_model_inference_spec') - @patch.object(MockModelBuilderServers, '_auto_detect_image_uri') - @patch.object(MockModelBuilderServers, '_prepare_for_mode') - @patch.object(MockModelBuilderServers, '_create_model') - def test_build_local_container_mode(self, mock_create, mock_prepare, mock_detect, mock_save, mock_ts_prepare): + + @patch("sagemaker.serve.model_builder_servers.prepare_for_torchserve") + @patch.object(MockModelBuilderServers, "_save_model_inference_spec") + @patch.object(MockModelBuilderServers, "_auto_detect_image_uri") + @patch.object(MockModelBuilderServers, "_prepare_for_mode") + @patch.object(MockModelBuilderServers, "_create_model") + def test_build_local_container_mode( + self, mock_create, mock_prepare, mock_detect, mock_save, mock_ts_prepare + ): """Test building for LOCAL_CONTAINER mode.""" self.builder.mode = Mode.LOCAL_CONTAINER self.builder.model = Mock() - mock_ts_prepare.return_value = "secret123" + mock_ts_prepare.return_value = "" mock_create.return_value = Mock() - + result = self.builder._build_for_torchserve() - + mock_ts_prepare.assert_called_once() - self.assertEqual(self.builder.secret_key, "secret123") + self.assertEqual(self.builder.secret_key, "") mock_create.assert_called_once() - - @patch('sagemaker.serve.model_builder_servers.prepare_for_torchserve') - @patch.object(MockModelBuilderServers, '_save_model_inference_spec') - @patch.object(MockModelBuilderServers, '_auto_detect_image_uri') - @patch.object(MockModelBuilderServers, '_prepare_for_mode') - @patch.object(MockModelBuilderServers, '_create_model') - def test_build_sagemaker_endpoint_mode(self, mock_create, mock_prepare, mock_detect, mock_save, mock_ts_prepare): + + @patch("sagemaker.serve.model_builder_servers.prepare_for_torchserve") + @patch.object(MockModelBuilderServers, "_save_model_inference_spec") + @patch.object(MockModelBuilderServers, "_auto_detect_image_uri") + @patch.object(MockModelBuilderServers, "_prepare_for_mode") + @patch.object(MockModelBuilderServers, "_create_model") + def test_build_sagemaker_endpoint_mode( + self, mock_create, mock_prepare, mock_detect, mock_save, mock_ts_prepare + ): """Test building for SAGEMAKER_ENDPOINT mode.""" self.builder.mode = Mode.SAGEMAKER_ENDPOINT self.builder.model = Mock() - mock_ts_prepare.return_value = "secret456" + mock_ts_prepare.return_value = "" mock_create.return_value = Mock() mock_prepare.return_value = ("s3://bucket/model.tar.gz", None) - + result = self.builder._build_for_torchserve() - + mock_ts_prepare.assert_called_once() - self.assertEqual(self.builder.secret_key, "secret456") + self.assertEqual(self.builder.secret_key, "") mock_prepare.assert_called_with(should_upload_artifacts=True) class TestBuildForTGI(unittest.TestCase): """Test _build_for_tgi method.""" - + def setUp(self): self.builder = MockModelBuilderServers() self.builder.model_server = ModelServer.TGI - - 
@patch('sagemaker.serve.model_builder_servers._get_nb_instance') - @patch('sagemaker.serve.model_server.tgi.prepare._create_dir_structure') - @patch.object(MockModelBuilderServers, '_validate_tgi_serving_sample_data') - @patch.object(MockModelBuilderServers, '_auto_detect_image_uri') - @patch.object(MockModelBuilderServers, '_prepare_for_mode') - @patch.object(MockModelBuilderServers, '_create_model') - def test_build_with_notebook_instance(self, mock_create, mock_prepare, mock_detect, - mock_validate, mock_dir, mock_nb): + + @patch("sagemaker.serve.model_builder_servers._get_nb_instance") + @patch("sagemaker.serve.model_server.tgi.prepare._create_dir_structure") + @patch.object(MockModelBuilderServers, "_validate_tgi_serving_sample_data") + @patch.object(MockModelBuilderServers, "_auto_detect_image_uri") + @patch.object(MockModelBuilderServers, "_prepare_for_mode") + @patch.object(MockModelBuilderServers, "_create_model") + def test_build_with_notebook_instance( + self, mock_create, mock_prepare, mock_detect, mock_validate, mock_dir, mock_nb + ): """Test building with notebook instance detection.""" mock_nb.return_value = "ml.g4dn.xlarge" mock_create.return_value = Mock() mock_prepare.return_value = ("s3://bucket/model.tar.gz", None) self.builder.model = Mock() - + result = self.builder._build_for_tgi() - + self.assertEqual(self.builder.instance_type, "ml.g4dn.xlarge") mock_create.assert_called_once() - @patch('sagemaker.serve.model_builder_servers._get_model_config_properties_from_hf') - @patch('sagemaker.serve.model_builder_servers._get_default_tgi_configurations') - @patch('sagemaker.serve.model_builder_servers._get_nb_instance') - @patch('sagemaker.serve.model_server.tgi.prepare._create_dir_structure') - @patch.object(MockModelBuilderServers, '_validate_tgi_serving_sample_data') - @patch.object(MockModelBuilderServers, '_is_jumpstart_model_id') - @patch.object(MockModelBuilderServers, '_auto_detect_image_uri') - @patch.object(MockModelBuilderServers, '_prepare_for_mode') - @patch.object(MockModelBuilderServers, '_create_model') - def test_build_with_hf_model(self, mock_create, mock_prepare, mock_detect, mock_js, - mock_validate, mock_dir, mock_nb, mock_tgi_config, mock_hf_config): + @patch("sagemaker.serve.model_builder_servers._get_model_config_properties_from_hf") + @patch("sagemaker.serve.model_builder_servers._get_default_tgi_configurations") + @patch("sagemaker.serve.model_builder_servers._get_nb_instance") + @patch("sagemaker.serve.model_server.tgi.prepare._create_dir_structure") + @patch.object(MockModelBuilderServers, "_validate_tgi_serving_sample_data") + @patch.object(MockModelBuilderServers, "_is_jumpstart_model_id") + @patch.object(MockModelBuilderServers, "_auto_detect_image_uri") + @patch.object(MockModelBuilderServers, "_prepare_for_mode") + @patch.object(MockModelBuilderServers, "_create_model") + def test_build_with_hf_model( + self, + mock_create, + mock_prepare, + mock_detect, + mock_js, + mock_validate, + mock_dir, + mock_nb, + mock_tgi_config, + mock_hf_config, + ): """Test building with HuggingFace model.""" mock_js.return_value = False mock_nb.return_value = None @@ -310,25 +328,34 @@ def test_build_with_hf_model(self, mock_create, mock_prepare, mock_detect, mock_ self.builder.model = "gpt2" self.builder.mode = Mode.LOCAL_CONTAINER self.builder.env_vars = {"HUGGING_FACE_HUB_TOKEN": "token"} - + result = self.builder._build_for_tgi() - + self.assertEqual(self.builder.env_vars["HF_MODEL_ID"], "gpt2") self.assertEqual(self.builder.env_vars["HF_TOKEN"], "token") 
self.assertEqual(self.builder.env_vars["SHARDED"], "false") self.assertEqual(self.builder.env_vars["NUM_SHARD"], "1") mock_create.assert_called_once() - - @patch('sagemaker.serve.model_builder_servers._get_gpu_info') - @patch('sagemaker.serve.model_builder_servers._get_default_tensor_parallel_degree') - @patch('sagemaker.serve.model_builder_servers._get_nb_instance') - @patch('sagemaker.serve.model_server.tgi.prepare._create_dir_structure') - @patch.object(MockModelBuilderServers, '_validate_tgi_serving_sample_data') - @patch.object(MockModelBuilderServers, '_auto_detect_image_uri') - @patch.object(MockModelBuilderServers, '_prepare_for_mode') - @patch.object(MockModelBuilderServers, '_create_model') - def test_build_sagemaker_endpoint_with_gpu(self, mock_create, mock_prepare, mock_detect, - mock_validate, mock_dir, mock_nb, mock_tp, mock_gpu): + + @patch("sagemaker.serve.model_builder_servers._get_gpu_info") + @patch("sagemaker.serve.model_builder_servers._get_default_tensor_parallel_degree") + @patch("sagemaker.serve.model_builder_servers._get_nb_instance") + @patch("sagemaker.serve.model_server.tgi.prepare._create_dir_structure") + @patch.object(MockModelBuilderServers, "_validate_tgi_serving_sample_data") + @patch.object(MockModelBuilderServers, "_auto_detect_image_uri") + @patch.object(MockModelBuilderServers, "_prepare_for_mode") + @patch.object(MockModelBuilderServers, "_create_model") + def test_build_sagemaker_endpoint_with_gpu( + self, + mock_create, + mock_prepare, + mock_detect, + mock_validate, + mock_dir, + mock_nb, + mock_tp, + mock_gpu, + ): """Test building for SAGEMAKER_ENDPOINT with GPU sharding.""" mock_nb.return_value = None mock_gpu.return_value = 4 @@ -338,24 +365,34 @@ def test_build_sagemaker_endpoint_with_gpu(self, mock_create, mock_prepare, mock self.builder.mode = Mode.SAGEMAKER_ENDPOINT self.builder.model = Mock() self.builder.hf_model_config = {"model_type": "gpt2"} - + result = self.builder._build_for_tgi() - + self.assertEqual(self.builder.env_vars["NUM_SHARD"], "2") self.assertEqual(self.builder.env_vars["SHARDED"], "true") mock_create.assert_called_once() - - @patch('sagemaker.serve.model_builder_servers._get_gpu_info_fallback') - @patch('sagemaker.serve.model_builder_servers._get_gpu_info') - @patch('sagemaker.serve.model_builder_servers._get_default_tensor_parallel_degree') - @patch('sagemaker.serve.model_builder_servers._get_nb_instance') - @patch('sagemaker.serve.model_server.djl_serving.prepare._create_dir_structure') - @patch.object(MockModelBuilderServers, '_validate_tgi_serving_sample_data') - @patch.object(MockModelBuilderServers, '_auto_detect_image_uri') - @patch.object(MockModelBuilderServers, '_prepare_for_mode') - @patch.object(MockModelBuilderServers, '_create_model') - def test_build_gpu_fallback(self, mock_create, mock_prepare, mock_detect, mock_validate, - mock_dir, mock_nb, mock_tp, mock_gpu, mock_fallback): + + @patch("sagemaker.serve.model_builder_servers._get_gpu_info_fallback") + @patch("sagemaker.serve.model_builder_servers._get_gpu_info") + @patch("sagemaker.serve.model_builder_servers._get_default_tensor_parallel_degree") + @patch("sagemaker.serve.model_builder_servers._get_nb_instance") + @patch("sagemaker.serve.model_server.djl_serving.prepare._create_dir_structure") + @patch.object(MockModelBuilderServers, "_validate_tgi_serving_sample_data") + @patch.object(MockModelBuilderServers, "_auto_detect_image_uri") + @patch.object(MockModelBuilderServers, "_prepare_for_mode") + @patch.object(MockModelBuilderServers, "_create_model") + 
def test_build_gpu_fallback( + self, + mock_create, + mock_prepare, + mock_detect, + mock_validate, + mock_dir, + mock_nb, + mock_tp, + mock_gpu, + mock_fallback, + ): """Test GPU info fallback when primary method fails.""" mock_nb.return_value = None mock_gpu.side_effect = Exception("GPU info failed") @@ -365,28 +402,29 @@ def test_build_gpu_fallback(self, mock_create, mock_prepare, mock_detect, mock_v mock_prepare.return_value = ("s3://bucket/model.tar.gz", None) self.builder.mode = Mode.SAGEMAKER_ENDPOINT self.builder.model = Mock() - + result = self.builder._build_for_tgi() - + mock_fallback.assert_called_once() mock_create.assert_called_once() class TestBuildForDJL(unittest.TestCase): """Test _build_for_djl method.""" - + def setUp(self): self.builder = MockModelBuilderServers() self.builder.model_server = ModelServer.DJL_SERVING - - @patch('sagemaker.serve.model_builder_servers._get_nb_instance') - @patch('sagemaker.serve.model_server.djl_serving.prepare._create_dir_structure') - @patch.object(MockModelBuilderServers, '_validate_djl_serving_sample_data') - @patch.object(MockModelBuilderServers, '_auto_detect_image_uri') - @patch.object(MockModelBuilderServers, '_prepare_for_mode') - @patch.object(MockModelBuilderServers, '_create_model') - def test_build_with_timeout(self, mock_create, mock_prepare, mock_detect, - mock_validate, mock_dir, mock_nb): + + @patch("sagemaker.serve.model_builder_servers._get_nb_instance") + @patch("sagemaker.serve.model_server.djl_serving.prepare._create_dir_structure") + @patch.object(MockModelBuilderServers, "_validate_djl_serving_sample_data") + @patch.object(MockModelBuilderServers, "_auto_detect_image_uri") + @patch.object(MockModelBuilderServers, "_prepare_for_mode") + @patch.object(MockModelBuilderServers, "_create_model") + def test_build_with_timeout( + self, mock_create, mock_prepare, mock_detect, mock_validate, mock_dir, mock_nb + ): """Test building with model_data_download_timeout.""" mock_nb.return_value = None mock_create.return_value = Mock() @@ -394,23 +432,33 @@ def test_build_with_timeout(self, mock_create, mock_prepare, mock_detect, self.builder.model = Mock() self.builder.mode = Mode.LOCAL_CONTAINER self.builder.model_data_download_timeout = 600 - + result = self.builder._build_for_djl() - + self.assertEqual(self.builder.env_vars["MODEL_LOADING_TIMEOUT"], "600") mock_create.assert_called_once() - - @patch('sagemaker.serve.model_builder_servers._get_model_config_properties_from_hf') - @patch('sagemaker.serve.model_builder_servers._get_default_djl_configurations') - @patch('sagemaker.serve.model_builder_servers._get_nb_instance') - @patch('sagemaker.serve.model_server.djl_serving.prepare._create_dir_structure') - @patch.object(MockModelBuilderServers, '_validate_djl_serving_sample_data') - @patch.object(MockModelBuilderServers, '_is_jumpstart_model_id') - @patch.object(MockModelBuilderServers, '_auto_detect_image_uri') - @patch.object(MockModelBuilderServers, '_prepare_for_mode') - @patch.object(MockModelBuilderServers, '_create_model') - def test_build_with_hf_model(self, mock_create, mock_prepare, mock_detect, mock_js, - mock_validate, mock_dir, mock_nb, mock_djl_config, mock_hf_config): + + @patch("sagemaker.serve.model_builder_servers._get_model_config_properties_from_hf") + @patch("sagemaker.serve.model_builder_servers._get_default_djl_configurations") + @patch("sagemaker.serve.model_builder_servers._get_nb_instance") + @patch("sagemaker.serve.model_server.djl_serving.prepare._create_dir_structure") + 
@patch.object(MockModelBuilderServers, "_validate_djl_serving_sample_data") + @patch.object(MockModelBuilderServers, "_is_jumpstart_model_id") + @patch.object(MockModelBuilderServers, "_auto_detect_image_uri") + @patch.object(MockModelBuilderServers, "_prepare_for_mode") + @patch.object(MockModelBuilderServers, "_create_model") + def test_build_with_hf_model( + self, + mock_create, + mock_prepare, + mock_detect, + mock_js, + mock_validate, + mock_dir, + mock_nb, + mock_djl_config, + mock_hf_config, + ): """Test building with HuggingFace model.""" mock_js.return_value = False mock_nb.return_value = None @@ -421,24 +469,33 @@ def test_build_with_hf_model(self, mock_create, mock_prepare, mock_detect, mock_ self.builder.model = "gpt2" self.builder.mode = Mode.LOCAL_CONTAINER self.builder.env_vars = {"HUGGING_FACE_HUB_TOKEN": "token"} - + result = self.builder._build_for_djl() - + self.assertEqual(self.builder.env_vars["HF_MODEL_ID"], "gpt2") self.assertEqual(self.builder.env_vars["HF_TOKEN"], "token") self.assertEqual(self.builder.env_vars["OPTION_ENGINE"], "Python") mock_create.assert_called_once() - - @patch('sagemaker.serve.model_builder_servers._get_gpu_info') - @patch('sagemaker.serve.model_builder_servers._get_default_tensor_parallel_degree') - @patch('sagemaker.serve.model_builder_servers._get_nb_instance') - @patch('sagemaker.serve.model_server.djl_serving.prepare._create_dir_structure') - @patch.object(MockModelBuilderServers, '_validate_djl_serving_sample_data') - @patch.object(MockModelBuilderServers, '_auto_detect_image_uri') - @patch.object(MockModelBuilderServers, '_prepare_for_mode') - @patch.object(MockModelBuilderServers, '_create_model') - def test_build_sagemaker_endpoint_tensor_parallel(self, mock_create, mock_prepare, mock_detect, - mock_validate, mock_dir, mock_nb, mock_tp, mock_gpu): + + @patch("sagemaker.serve.model_builder_servers._get_gpu_info") + @patch("sagemaker.serve.model_builder_servers._get_default_tensor_parallel_degree") + @patch("sagemaker.serve.model_builder_servers._get_nb_instance") + @patch("sagemaker.serve.model_server.djl_serving.prepare._create_dir_structure") + @patch.object(MockModelBuilderServers, "_validate_djl_serving_sample_data") + @patch.object(MockModelBuilderServers, "_auto_detect_image_uri") + @patch.object(MockModelBuilderServers, "_prepare_for_mode") + @patch.object(MockModelBuilderServers, "_create_model") + def test_build_sagemaker_endpoint_tensor_parallel( + self, + mock_create, + mock_prepare, + mock_detect, + mock_validate, + mock_dir, + mock_nb, + mock_tp, + mock_gpu, + ): """Test building for SAGEMAKER_ENDPOINT with tensor parallelism.""" mock_nb.return_value = None mock_gpu.return_value = 4 @@ -448,29 +505,37 @@ def test_build_sagemaker_endpoint_tensor_parallel(self, mock_create, mock_prepar self.builder.mode = Mode.SAGEMAKER_ENDPOINT self.builder.model = Mock() self.builder.hf_model_config = {"model_type": "gpt2"} - + result = self.builder._build_for_djl() - + self.assertEqual(self.builder.env_vars["TENSOR_PARALLEL_DEGREE"], "4") mock_create.assert_called_once() class TestBuildForTriton(unittest.TestCase): """Test _build_for_triton method.""" - + def setUp(self): self.builder = MockModelBuilderServers() self.builder.model_server = ModelServer.TRITON - - @patch.object(MockModelBuilderServers, 'get_huggingface_model_metadata') - @patch.object(MockModelBuilderServers, '_validate_for_triton') - @patch.object(MockModelBuilderServers, '_is_jumpstart_model_id') - @patch.object(MockModelBuilderServers, '_save_inference_spec') - 
@patch.object(MockModelBuilderServers, '_prepare_for_triton') - @patch.object(MockModelBuilderServers, '_prepare_for_mode') - @patch.object(MockModelBuilderServers, '_create_model') - def test_build_with_hf_model_string(self, mock_create, mock_prepare_mode, mock_prepare_triton, - mock_save, mock_js, mock_validate, mock_hf_meta): + + @patch.object(MockModelBuilderServers, "get_huggingface_model_metadata") + @patch.object(MockModelBuilderServers, "_validate_for_triton") + @patch.object(MockModelBuilderServers, "_is_jumpstart_model_id") + @patch.object(MockModelBuilderServers, "_save_inference_spec") + @patch.object(MockModelBuilderServers, "_prepare_for_triton") + @patch.object(MockModelBuilderServers, "_prepare_for_mode") + @patch.object(MockModelBuilderServers, "_create_model") + def test_build_with_hf_model_string( + self, + mock_create, + mock_prepare_mode, + mock_prepare_triton, + mock_save, + mock_js, + mock_validate, + mock_hf_meta, + ): """Test building with HuggingFace model string.""" mock_js.return_value = False mock_hf_meta.return_value = {"pipeline_tag": "text-generation"} @@ -478,26 +543,35 @@ def test_build_with_hf_model_string(self, mock_create, mock_prepare_mode, mock_p mock_prepare_mode.return_value = ("s3://bucket/model.tar.gz", None) self.builder.model = "gpt2" self.builder.env_vars = {"HUGGING_FACE_HUB_TOKEN": "token"} - + result = self.builder._build_for_triton() - + self.assertEqual(self.builder.env_vars["HF_MODEL_ID"], "gpt2") self.assertEqual(self.builder.env_vars["HF_TASK"], "text-generation") self.assertEqual(self.builder.env_vars["HF_TOKEN"], "token") mock_create.assert_called_once() - - @patch('sagemaker.serve.model_builder_servers._detect_framework_and_version') - @patch('sagemaker.serve.model_builder_servers._get_model_base') - @patch.object(MockModelBuilderServers, '_normalize_framework_to_enum') - @patch.object(MockModelBuilderServers, '_validate_for_triton') - @patch.object(MockModelBuilderServers, '_auto_detect_image_for_triton') - @patch.object(MockModelBuilderServers, '_save_inference_spec') - @patch.object(MockModelBuilderServers, '_prepare_for_triton') - @patch.object(MockModelBuilderServers, '_prepare_for_mode') - @patch.object(MockModelBuilderServers, '_create_model') - def test_build_with_model_object(self, mock_create, mock_prepare_mode, mock_prepare_triton, - mock_save, mock_detect_img, mock_validate, mock_normalize, - mock_base, mock_detect_fw): + + @patch("sagemaker.serve.model_builder_servers._detect_framework_and_version") + @patch("sagemaker.serve.model_builder_servers._get_model_base") + @patch.object(MockModelBuilderServers, "_normalize_framework_to_enum") + @patch.object(MockModelBuilderServers, "_validate_for_triton") + @patch.object(MockModelBuilderServers, "_auto_detect_image_for_triton") + @patch.object(MockModelBuilderServers, "_save_inference_spec") + @patch.object(MockModelBuilderServers, "_prepare_for_triton") + @patch.object(MockModelBuilderServers, "_prepare_for_mode") + @patch.object(MockModelBuilderServers, "_create_model") + def test_build_with_model_object( + self, + mock_create, + mock_prepare_mode, + mock_prepare_triton, + mock_save, + mock_detect_img, + mock_validate, + mock_normalize, + mock_base, + mock_detect_fw, + ): """Test building with model object.""" mock_base.return_value = "pytorch_model" mock_detect_fw.return_value = ("pytorch", "1.8.0") @@ -506,9 +580,9 @@ def test_build_with_model_object(self, mock_create, mock_prepare_mode, mock_prep mock_prepare_mode.return_value = ("s3://bucket/model.tar.gz", None) 
self.builder.model = Mock() self.builder.image_uri = None - + result = self.builder._build_for_triton() - + self.assertEqual(self.builder.framework_version, "1.8.0") mock_detect_img.assert_called_once() mock_create.assert_called_once() @@ -516,40 +590,40 @@ def test_build_with_model_object(self, mock_create, mock_prepare_mode, mock_prep class TestBuildForTensorFlowServing(unittest.TestCase): """Test _build_for_tensorflow_serving method.""" - + def setUp(self): self.builder = MockModelBuilderServers() self.builder.model_server = ModelServer.TENSORFLOW_SERVING self.builder._is_mlflow_model = True - - @patch('sagemaker.serve.model_builder_servers.save_pkl') - @patch('sagemaker.serve.model_builder_servers.prepare_for_tf_serving') - @patch.object(MockModelBuilderServers, '_prepare_for_mode') - @patch.object(MockModelBuilderServers, '_create_model') + + @patch("sagemaker.serve.model_builder_servers.save_pkl") + @patch("sagemaker.serve.model_builder_servers.prepare_for_tf_serving") + @patch.object(MockModelBuilderServers, "_prepare_for_mode") + @patch.object(MockModelBuilderServers, "_create_model") def test_build_mlflow_model(self, mock_create, mock_prepare_mode, mock_tf_prepare, mock_save): """Test building MLflow model for TensorFlow Serving.""" - mock_tf_prepare.return_value = "secret789" + mock_tf_prepare.return_value = "" mock_create.return_value = Mock() mock_prepare_mode.return_value = ("s3://bucket/model.tar.gz", None) - + result = self.builder._build_for_tensorflow_serving() - - self.assertEqual(self.builder.secret_key, "secret789") + + self.assertEqual(self.builder.secret_key, "") mock_save.assert_called_once() mock_create.assert_called_once() - + def test_build_non_mlflow_model_error(self): """Test error when building non-MLflow model.""" self.builder._is_mlflow_model = False - + with self.assertRaises(ValueError) as ctx: self.builder._build_for_tensorflow_serving() self.assertIn("mlflow", str(ctx.exception).lower()) - + def test_build_missing_image_uri_error(self): """Test error when image_uri is missing.""" self.builder.image_uri = None - + with self.assertRaises(ValueError) as ctx: self.builder._build_for_tensorflow_serving() self.assertIn("image_uri", str(ctx.exception)) @@ -557,20 +631,21 @@ def test_build_missing_image_uri_error(self): class TestBuildForTEI(unittest.TestCase): """Test _build_for_tei method.""" - + def setUp(self): self.builder = MockModelBuilderServers() self.builder.model_server = ModelServer.TEI - - @patch('sagemaker.serve.model_builder_servers._get_model_config_properties_from_hf') - @patch('sagemaker.serve.model_builder_servers._get_nb_instance') - @patch('sagemaker.serve.model_server.tgi.prepare._create_dir_structure') - @patch.object(MockModelBuilderServers, '_is_jumpstart_model_id') - @patch.object(MockModelBuilderServers, '_auto_detect_image_uri') - @patch.object(MockModelBuilderServers, '_prepare_for_mode') - @patch.object(MockModelBuilderServers, '_create_model') - def test_build_with_hf_model(self, mock_create, mock_prepare, mock_detect, mock_js, - mock_dir, mock_nb, mock_hf_config): + + @patch("sagemaker.serve.model_builder_servers._get_model_config_properties_from_hf") + @patch("sagemaker.serve.model_builder_servers._get_nb_instance") + @patch("sagemaker.serve.model_server.tgi.prepare._create_dir_structure") + @patch.object(MockModelBuilderServers, "_is_jumpstart_model_id") + @patch.object(MockModelBuilderServers, "_auto_detect_image_uri") + @patch.object(MockModelBuilderServers, "_prepare_for_mode") + @patch.object(MockModelBuilderServers, 
"_create_model") + def test_build_with_hf_model( + self, mock_create, mock_prepare, mock_detect, mock_js, mock_dir, mock_nb, mock_hf_config + ): """Test building with HuggingFace model.""" mock_js.return_value = False mock_nb.return_value = None @@ -579,27 +654,28 @@ def test_build_with_hf_model(self, mock_create, mock_prepare, mock_detect, mock_ mock_prepare.return_value = ("s3://bucket/model.tar.gz", None) self.builder.model = "bert-base-uncased" self.builder.env_vars = {"HUGGING_FACE_HUB_TOKEN": "token"} - + result = self.builder._build_for_tei() - + self.assertEqual(self.builder.env_vars["HF_MODEL_ID"], "bert-base-uncased") self.assertEqual(self.builder.env_vars["HF_TOKEN"], "token") mock_create.assert_called_once() - - @patch('sagemaker.serve.model_builder_servers._get_nb_instance') - @patch('sagemaker.serve.model_server.tgi.prepare._create_dir_structure') - @patch.object(MockModelBuilderServers, '_auto_detect_image_uri') - @patch.object(MockModelBuilderServers, '_prepare_for_mode') - @patch.object(MockModelBuilderServers, '_create_model') - def test_build_sagemaker_endpoint_missing_instance_type(self, mock_create, mock_prepare, - mock_detect, mock_dir, mock_nb): + + @patch("sagemaker.serve.model_builder_servers._get_nb_instance") + @patch("sagemaker.serve.model_server.tgi.prepare._create_dir_structure") + @patch.object(MockModelBuilderServers, "_auto_detect_image_uri") + @patch.object(MockModelBuilderServers, "_prepare_for_mode") + @patch.object(MockModelBuilderServers, "_create_model") + def test_build_sagemaker_endpoint_missing_instance_type( + self, mock_create, mock_prepare, mock_detect, mock_dir, mock_nb + ): """Test error when instance_type is missing for SAGEMAKER_ENDPOINT.""" mock_nb.return_value = None self.builder.mode = Mode.SAGEMAKER_ENDPOINT self.builder.instance_type = None self.builder.model = Mock() mock_prepare.return_value = ("s3://bucket/model.tar.gz", None) - + with self.assertRaises(ValueError) as ctx: self.builder._build_for_tei() self.assertIn("Instance type", str(ctx.exception)) @@ -607,76 +683,92 @@ def test_build_sagemaker_endpoint_missing_instance_type(self, mock_create, mock_ class TestBuildForSMD(unittest.TestCase): """Test _build_for_smd method.""" - + def setUp(self): self.builder = MockModelBuilderServers() self.builder.model_server = ModelServer.SMD - - @patch('sagemaker.serve.model_builder_servers.prepare_for_smd') - @patch.object(MockModelBuilderServers, '_save_model_inference_spec') - @patch.object(MockModelBuilderServers, '_get_processing_unit') - @patch.object(MockModelBuilderServers, '_get_smd_image_uri') - @patch.object(MockModelBuilderServers, '_prepare_for_mode') - @patch.object(MockModelBuilderServers, '_create_model') - def test_build_with_auto_image(self, mock_create, mock_prepare_mode, mock_get_img, - mock_get_unit, mock_save, mock_smd_prepare): + + @patch("sagemaker.serve.model_builder_servers.prepare_for_smd") + @patch.object(MockModelBuilderServers, "_save_model_inference_spec") + @patch.object(MockModelBuilderServers, "_get_processing_unit") + @patch.object(MockModelBuilderServers, "_get_smd_image_uri") + @patch.object(MockModelBuilderServers, "_prepare_for_mode") + @patch.object(MockModelBuilderServers, "_create_model") + def test_build_with_auto_image( + self, + mock_create, + mock_prepare_mode, + mock_get_img, + mock_get_unit, + mock_save, + mock_smd_prepare, + ): """Test building with auto-detected image.""" mock_get_unit.return_value = "gpu" mock_get_img.return_value = "smd-image-uri" - mock_smd_prepare.return_value = "secret999" 
+ mock_smd_prepare.return_value = "" mock_create.return_value = Mock() self.builder.mode = Mode.LOCAL_CONTAINER self.builder.image_uri = None self.builder.model = Mock() - + result = self.builder._build_for_smd() - + self.assertEqual(self.builder.image_uri, "smd-image-uri") - self.assertEqual(self.builder.secret_key, "secret999") + self.assertEqual(self.builder.secret_key, "") mock_create.assert_called_once() class TestBuildForTransformers(unittest.TestCase): """Test _build_for_transformers method.""" - + def setUp(self): self.builder = MockModelBuilderServers() self.builder.model_server = ModelServer.MMS - - @patch('sagemaker.serve.model_builder_servers.save_pkl') - @patch('sagemaker.serve.model_builder_servers.prepare_for_mms') - @patch('sagemaker.serve.model_builder_servers._get_nb_instance') - @patch('sagemaker.serve.model_server.multi_model_server.prepare._create_dir_structure') - @patch.object(MockModelBuilderServers, '_auto_detect_image_uri') - @patch.object(MockModelBuilderServers, '_create_conda_env') - @patch.object(MockModelBuilderServers, '_prepare_for_mode') - @patch.object(MockModelBuilderServers, '_create_model') - def test_build_with_inference_spec_local_container(self, mock_create, mock_prepare_mode, - mock_conda, mock_detect, mock_dir, - mock_nb, mock_mms_prepare, mock_save): + + @patch("sagemaker.serve.model_builder_servers.save_pkl") + @patch("sagemaker.serve.model_builder_servers.prepare_for_mms") + @patch("sagemaker.serve.model_builder_servers._get_nb_instance") + @patch("sagemaker.serve.model_server.multi_model_server.prepare._create_dir_structure") + @patch.object(MockModelBuilderServers, "_auto_detect_image_uri") + @patch.object(MockModelBuilderServers, "_create_conda_env") + @patch.object(MockModelBuilderServers, "_prepare_for_mode") + @patch.object(MockModelBuilderServers, "_create_model") + def test_build_with_inference_spec_local_container( + self, + mock_create, + mock_prepare_mode, + mock_conda, + mock_detect, + mock_dir, + mock_nb, + mock_mms_prepare, + mock_save, + ): """Test building with inference_spec for LOCAL_CONTAINER.""" mock_nb.return_value = None - mock_mms_prepare.return_value = "secret111" + mock_mms_prepare.return_value = "" mock_create.return_value = Mock() self.builder.mode = Mode.LOCAL_CONTAINER self.builder.inference_spec = Mock() - + result = self.builder._build_for_transformers() - + mock_save.assert_called_once() mock_mms_prepare.assert_called_once() - self.assertEqual(self.builder.secret_key, "secret111") + self.assertEqual(self.builder.secret_key, "") mock_create.assert_called_once() - - @patch('sagemaker.serve.model_builder_servers._get_model_config_properties_from_hf') - @patch('sagemaker.serve.model_builder_servers._get_nb_instance') - @patch('sagemaker.serve.model_server.multi_model_server.prepare._create_dir_structure') - @patch.object(MockModelBuilderServers, '_is_jumpstart_model_id') - @patch.object(MockModelBuilderServers, '_auto_detect_image_uri') - @patch.object(MockModelBuilderServers, '_prepare_for_mode') - @patch.object(MockModelBuilderServers, '_create_model') - def test_build_with_hf_model_string(self, mock_create, mock_prepare, mock_detect, mock_js, - mock_dir, mock_nb, mock_hf_config): + + @patch("sagemaker.serve.model_builder_servers._get_model_config_properties_from_hf") + @patch("sagemaker.serve.model_builder_servers._get_nb_instance") + @patch("sagemaker.serve.model_server.multi_model_server.prepare._create_dir_structure") + @patch.object(MockModelBuilderServers, "_is_jumpstart_model_id") + 
@patch.object(MockModelBuilderServers, "_auto_detect_image_uri") + @patch.object(MockModelBuilderServers, "_prepare_for_mode") + @patch.object(MockModelBuilderServers, "_create_model") + def test_build_with_hf_model_string( + self, mock_create, mock_prepare, mock_detect, mock_js, mock_dir, mock_nb, mock_hf_config + ): """Test building with HuggingFace model string.""" mock_js.return_value = False mock_nb.return_value = None @@ -685,62 +777,66 @@ def test_build_with_hf_model_string(self, mock_create, mock_prepare, mock_detect mock_prepare.return_value = ("s3://bucket/model.tar.gz", None) self.builder.model = "gpt2" self.builder.env_vars = {"HUGGING_FACE_HUB_TOKEN": "token"} - + result = self.builder._build_for_transformers() - + self.assertEqual(self.builder.env_vars["HF_MODEL_ID"], "gpt2") mock_create.assert_called_once() - - @patch('sagemaker.serve.model_builder_servers._get_nb_instance') - @patch('sagemaker.serve.model_server.multi_model_server.prepare._create_dir_structure') - @patch.object(MockModelBuilderServers, '_auto_detect_image_uri') - @patch.object(MockModelBuilderServers, '_prepare_for_mode') - @patch.object(MockModelBuilderServers, '_create_model') - def test_build_sagemaker_endpoint_missing_instance_type(self, mock_create, mock_prepare, - mock_detect, mock_dir, mock_nb): + + @patch("sagemaker.serve.model_builder_servers._get_nb_instance") + @patch("sagemaker.serve.model_server.multi_model_server.prepare._create_dir_structure") + @patch.object(MockModelBuilderServers, "_auto_detect_image_uri") + @patch.object(MockModelBuilderServers, "_prepare_for_mode") + @patch.object(MockModelBuilderServers, "_create_model") + def test_build_sagemaker_endpoint_missing_instance_type( + self, mock_create, mock_prepare, mock_detect, mock_dir, mock_nb + ): """Test error when instance_type is missing for SAGEMAKER_ENDPOINT.""" mock_nb.return_value = None self.builder.mode = Mode.SAGEMAKER_ENDPOINT self.builder.instance_type = None self.builder.model = Mock() mock_prepare.return_value = ("s3://bucket/model.tar.gz", None) - + with self.assertRaises(ValueError) as ctx: self.builder._build_for_transformers() self.assertIn("Instance type", str(ctx.exception)) - - @patch('sagemaker.serve.model_builder_servers._get_nb_instance') - @patch('sagemaker.serve.model_server.multi_model_server.prepare._create_dir_structure') - @patch.object(MockModelBuilderServers, '_auto_detect_image_uri') - @patch.object(MockModelBuilderServers, '_prepare_for_mode') - @patch.object(MockModelBuilderServers, '_create_model') - def test_build_clean_empty_secret_key(self, mock_create, mock_prepare, mock_detect, - mock_dir, mock_nb): + + @patch("sagemaker.serve.model_builder_servers._get_nb_instance") + @patch("sagemaker.serve.model_server.multi_model_server.prepare._create_dir_structure") + @patch.object(MockModelBuilderServers, "_auto_detect_image_uri") + @patch.object(MockModelBuilderServers, "_prepare_for_mode") + @patch.object(MockModelBuilderServers, "_create_model") + def test_build_clean_empty_secret_key( + self, mock_create, mock_prepare, mock_detect, mock_dir, mock_nb + ): """Test cleaning empty secret key from env_vars.""" mock_nb.return_value = None mock_create.return_value = Mock() mock_prepare.return_value = ("s3://bucket/model.tar.gz", None) self.builder.model = Mock() self.builder.env_vars["SAGEMAKER_SERVE_SECRET_KEY"] = "" - + result = self.builder._build_for_transformers() - + self.assertNotIn("SAGEMAKER_SERVE_SECRET_KEY", self.builder.env_vars) mock_create.assert_called_once() class 
TestBuildForJumpStart(unittest.TestCase): """Test _build_for_jumpstart and related methods.""" - + def setUp(self): self.builder = MockModelBuilderServers() self.builder.model = "huggingface-llm-falcon-7b" - - @patch('sagemaker.core.jumpstart.factory.utils.get_init_kwargs') - @patch('sagemaker.serve.model_builder_servers.prepare_djl_js_resources') - @patch.object(MockModelBuilderServers, '_prepare_for_mode') - @patch.object(MockModelBuilderServers, '_create_model') - def test_build_djl_local_container(self, mock_create, mock_prepare_mode, mock_djl_res, mock_init): + + @patch("sagemaker.core.jumpstart.factory.utils.get_init_kwargs") + @patch("sagemaker.serve.model_builder_servers.prepare_djl_js_resources") + @patch.object(MockModelBuilderServers, "_prepare_for_mode") + @patch.object(MockModelBuilderServers, "_create_model") + def test_build_djl_local_container( + self, mock_create, mock_prepare_mode, mock_djl_res, mock_init + ): """Test building DJL JumpStart model for LOCAL_CONTAINER.""" mock_init_kwargs = Mock() mock_init_kwargs.image_uri = "djl-inference:0.21.0" @@ -751,18 +847,20 @@ def test_build_djl_local_container(self, mock_create, mock_prepare_mode, mock_dj mock_create.return_value = Mock() self.builder.mode = Mode.LOCAL_CONTAINER self.builder.image_uri = None - + result = self.builder._build_for_jumpstart() - + self.assertEqual(self.builder.model_server, ModelServer.DJL_SERVING) self.assertTrue(self.builder.prepared_for_djl) mock_create.assert_called_once() - - @patch('sagemaker.core.jumpstart.factory.utils.get_init_kwargs') - @patch('sagemaker.serve.model_builder_servers.prepare_tgi_js_resources') - @patch.object(MockModelBuilderServers, '_prepare_for_mode') - @patch.object(MockModelBuilderServers, '_create_model') - def test_build_tgi_local_container(self, mock_create, mock_prepare_mode, mock_tgi_res, mock_init): + + @patch("sagemaker.core.jumpstart.factory.utils.get_init_kwargs") + @patch("sagemaker.serve.model_builder_servers.prepare_tgi_js_resources") + @patch.object(MockModelBuilderServers, "_prepare_for_mode") + @patch.object(MockModelBuilderServers, "_create_model") + def test_build_tgi_local_container( + self, mock_create, mock_prepare_mode, mock_tgi_res, mock_init + ): """Test building TGI JumpStart model for LOCAL_CONTAINER.""" mock_init_kwargs = Mock() mock_init_kwargs.image_uri = "tgi-inference:1.0.0" @@ -773,18 +871,20 @@ def test_build_tgi_local_container(self, mock_create, mock_prepare_mode, mock_tg mock_create.return_value = Mock() self.builder.mode = Mode.LOCAL_CONTAINER self.builder.image_uri = None - + result = self.builder._build_for_jumpstart() - + self.assertEqual(self.builder.model_server, ModelServer.TGI) self.assertTrue(self.builder.prepared_for_tgi) mock_create.assert_called_once() - - @patch('sagemaker.core.jumpstart.factory.utils.get_init_kwargs') - @patch('sagemaker.serve.model_builder_servers.prepare_mms_js_resources') - @patch.object(MockModelBuilderServers, '_prepare_for_mode') - @patch.object(MockModelBuilderServers, '_create_model') - def test_build_mms_local_container(self, mock_create, mock_prepare_mode, mock_mms_res, mock_init): + + @patch("sagemaker.core.jumpstart.factory.utils.get_init_kwargs") + @patch("sagemaker.serve.model_builder_servers.prepare_mms_js_resources") + @patch.object(MockModelBuilderServers, "_prepare_for_mode") + @patch.object(MockModelBuilderServers, "_create_model") + def test_build_mms_local_container( + self, mock_create, mock_prepare_mode, mock_mms_res, mock_init + ): """Test building MMS JumpStart model for 
LOCAL_CONTAINER.""" mock_init_kwargs = Mock() mock_init_kwargs.image_uri = "huggingface-pytorch-inference:1.10.0" @@ -795,14 +895,14 @@ def test_build_mms_local_container(self, mock_create, mock_prepare_mode, mock_mm mock_create.return_value = Mock() self.builder.mode = Mode.LOCAL_CONTAINER self.builder.image_uri = None - + result = self.builder._build_for_jumpstart() - + self.assertEqual(self.builder.model_server, ModelServer.MMS) self.assertTrue(self.builder.prepared_for_mms) mock_create.assert_called_once() - - @patch('sagemaker.core.jumpstart.factory.utils.get_init_kwargs') + + @patch("sagemaker.core.jumpstart.factory.utils.get_init_kwargs") def test_build_unsupported_image_uri(self, mock_init): """Test error for unsupported JumpStart image URI.""" mock_init_kwargs = Mock() @@ -812,16 +912,18 @@ def test_build_unsupported_image_uri(self, mock_init): mock_init.return_value = mock_init_kwargs self.builder.mode = Mode.LOCAL_CONTAINER self.builder.image_uri = None - + with self.assertRaises(ValueError) as ctx: self.builder._build_for_jumpstart() self.assertIn("Unsupported", str(ctx.exception)) - - @patch('sagemaker.core.jumpstart.factory.utils.get_init_kwargs') - @patch('sagemaker.serve.model_builder_servers.prepare_djl_js_resources') - @patch.object(MockModelBuilderServers, '_prepare_for_mode') - @patch.object(MockModelBuilderServers, '_create_model') - def test_build_passes_config_name_to_get_init_kwargs(self, mock_create, mock_prepare_mode, mock_djl_res, mock_init): + + @patch("sagemaker.core.jumpstart.factory.utils.get_init_kwargs") + @patch("sagemaker.serve.model_builder_servers.prepare_djl_js_resources") + @patch.object(MockModelBuilderServers, "_prepare_for_mode") + @patch.object(MockModelBuilderServers, "_create_model") + def test_build_passes_config_name_to_get_init_kwargs( + self, mock_create, mock_prepare_mode, mock_djl_res, mock_init + ): """Test that config_name is forwarded to get_init_kwargs.""" mock_init_kwargs = Mock() mock_init_kwargs.image_uri = "djl-inference:0.21.0" @@ -846,11 +948,13 @@ def test_build_passes_config_name_to_get_init_kwargs(self, mock_create, mock_pre config_name="lmi-optimized", ) - @patch('sagemaker.core.jumpstart.factory.utils.get_init_kwargs') - @patch('sagemaker.serve.model_builder_servers.prepare_djl_js_resources') - @patch.object(MockModelBuilderServers, '_prepare_for_mode') - @patch.object(MockModelBuilderServers, '_create_model') - def test_build_passes_none_config_name_when_not_set(self, mock_create, mock_prepare_mode, mock_djl_res, mock_init): + @patch("sagemaker.core.jumpstart.factory.utils.get_init_kwargs") + @patch("sagemaker.serve.model_builder_servers.prepare_djl_js_resources") + @patch.object(MockModelBuilderServers, "_prepare_for_mode") + @patch.object(MockModelBuilderServers, "_create_model") + def test_build_passes_none_config_name_when_not_set( + self, mock_create, mock_prepare_mode, mock_djl_res, mock_init + ): """Test that config_name defaults to None when not set.""" mock_init_kwargs = Mock() mock_init_kwargs.image_uri = "djl-inference:0.21.0" @@ -875,9 +979,9 @@ def test_build_passes_none_config_name_when_not_set(self, mock_create, mock_prep config_name=None, ) - @patch('sagemaker.core.jumpstart.factory.utils.get_init_kwargs') - @patch.object(MockModelBuilderServers, '_prepare_for_mode') - @patch.object(MockModelBuilderServers, '_build_for_djl_jumpstart') + @patch("sagemaker.core.jumpstart.factory.utils.get_init_kwargs") + @patch.object(MockModelBuilderServers, "_prepare_for_mode") + @patch.object(MockModelBuilderServers, 
"_build_for_djl_jumpstart") def test_build_sagemaker_endpoint_djl(self, mock_djl_build, mock_prepare, mock_init): """Test building DJL JumpStart for SAGEMAKER_ENDPOINT.""" mock_init_kwargs = Mock() @@ -888,157 +992,154 @@ def test_build_sagemaker_endpoint_djl(self, mock_djl_build, mock_prepare, mock_i mock_djl_build.return_value = Mock() self.builder.mode = Mode.SAGEMAKER_ENDPOINT self.builder.image_uri = None - + result = self.builder._build_for_jumpstart() - + mock_djl_build.assert_called_once() class TestDeployWrappers(unittest.TestCase): """Test deploy wrapper methods.""" - + def setUp(self): self.builder = MockModelBuilderServers() - - @patch.object(MockModelBuilderServers, '_deploy_local_endpoint') + + @patch.object(MockModelBuilderServers, "_deploy_local_endpoint") def test_djl_deploy_in_process(self, mock_deploy): """Test DJL deploy wrapper for IN_PROCESS mode.""" mock_deploy.return_value = Mock() self.builder.mode = Mode.IN_PROCESS - + result = self.builder._djl_model_builder_deploy_wrapper() - + mock_deploy.assert_called_once() - - @patch.object(MockModelBuilderServers, '_deploy_local_endpoint') + + @patch.object(MockModelBuilderServers, "_deploy_local_endpoint") def test_djl_deploy_local_container(self, mock_deploy): """Test DJL deploy wrapper for LOCAL_CONTAINER mode.""" mock_deploy.return_value = Mock() self.builder.mode = Mode.LOCAL_CONTAINER - + result = self.builder._djl_model_builder_deploy_wrapper() - + mock_deploy.assert_called_once() - - @patch.object(MockModelBuilderServers, '_deploy_core_endpoint') + + @patch.object(MockModelBuilderServers, "_deploy_core_endpoint") def test_djl_deploy_sagemaker_endpoint(self, mock_deploy): """Test DJL deploy wrapper for SAGEMAKER_ENDPOINT mode.""" mock_deploy.return_value = Mock() self.builder.mode = Mode.SAGEMAKER_ENDPOINT - - result = self.builder._djl_model_builder_deploy_wrapper( - model_data_download_timeout=600 - ) - + + result = self.builder._djl_model_builder_deploy_wrapper(model_data_download_timeout=600) + self.assertEqual(self.builder.env_vars["MODEL_LOADING_TIMEOUT"], "600") mock_deploy.assert_called_once() - - @patch.object(MockModelBuilderServers, '_deploy_core_endpoint') + + @patch.object(MockModelBuilderServers, "_deploy_core_endpoint") def test_djl_deploy_with_defaults(self, mock_deploy): """Test DJL deploy wrapper sets default values.""" mock_deploy.return_value = Mock() self.builder.mode = Mode.SAGEMAKER_ENDPOINT - + result = self.builder._djl_model_builder_deploy_wrapper() - + call_kwargs = mock_deploy.call_args[1] self.assertEqual(call_kwargs["endpoint_logging"], True) self.assertEqual(call_kwargs["initial_instance_count"], 1) - - @patch.object(MockModelBuilderServers, '_deploy_local_endpoint') + + @patch.object(MockModelBuilderServers, "_deploy_local_endpoint") def test_tgi_deploy_local_container(self, mock_deploy): """Test TGI deploy wrapper for LOCAL_CONTAINER mode.""" mock_deploy.return_value = Mock() self.builder.mode = Mode.LOCAL_CONTAINER - + result = self.builder._tgi_model_builder_deploy_wrapper() - + mock_deploy.assert_called_once() - - @patch.object(MockModelBuilderServers, '_deploy_core_endpoint') + + @patch.object(MockModelBuilderServers, "_deploy_core_endpoint") def test_tgi_deploy_sagemaker_endpoint(self, mock_deploy): """Test TGI deploy wrapper for SAGEMAKER_ENDPOINT mode.""" mock_deploy.return_value = Mock() self.builder.mode = Mode.SAGEMAKER_ENDPOINT - + result = self.builder._tgi_model_builder_deploy_wrapper() - + mock_deploy.assert_called_once() - - @patch.object(MockModelBuilderServers, 
'_deploy_local_endpoint') + + @patch.object(MockModelBuilderServers, "_deploy_local_endpoint") def test_tei_deploy_in_process(self, mock_deploy): """Test TEI deploy wrapper for IN_PROCESS mode.""" mock_deploy.return_value = Mock() self.builder.mode = Mode.IN_PROCESS - + result = self.builder._tei_model_builder_deploy_wrapper() - + mock_deploy.assert_called_once() - - @patch.object(MockModelBuilderServers, '_deploy_core_endpoint') + + @patch.object(MockModelBuilderServers, "_deploy_core_endpoint") def test_tei_deploy_sagemaker_endpoint(self, mock_deploy): """Test TEI deploy wrapper for SAGEMAKER_ENDPOINT mode.""" mock_deploy.return_value = Mock() self.builder.mode = Mode.SAGEMAKER_ENDPOINT - + result = self.builder._tei_model_builder_deploy_wrapper() - + mock_deploy.assert_called_once() - - @patch.object(MockModelBuilderServers, '_deploy_local_endpoint') + + @patch.object(MockModelBuilderServers, "_deploy_local_endpoint") def test_js_deploy_local_container(self, mock_deploy): """Test JumpStart deploy wrapper for LOCAL_CONTAINER mode.""" mock_deploy.return_value = Mock() self.builder.mode = Mode.LOCAL_CONTAINER - + result = self.builder._js_builder_deploy_wrapper() - + mock_deploy.assert_called_once() - - @patch.object(MockModelBuilderServers, '_deploy_core_endpoint') + + @patch.object(MockModelBuilderServers, "_deploy_core_endpoint") def test_js_deploy_sagemaker_endpoint(self, mock_deploy): """Test JumpStart deploy wrapper for SAGEMAKER_ENDPOINT mode.""" mock_deploy.return_value = Mock() self.builder.mode = Mode.SAGEMAKER_ENDPOINT self.builder.instance_type = "ml.g5.xlarge" - + result = self.builder._js_builder_deploy_wrapper() - + call_kwargs = mock_deploy.call_args[1] self.assertEqual(call_kwargs["instance_type"], "ml.g5.xlarge") mock_deploy.assert_called_once() - - @patch.object(MockModelBuilderServers, '_deploy_local_endpoint') + + @patch.object(MockModelBuilderServers, "_deploy_local_endpoint") def test_transformers_deploy_local_container(self, mock_deploy): """Test Transformers deploy wrapper for LOCAL_CONTAINER mode.""" mock_deploy.return_value = Mock() self.builder.mode = Mode.LOCAL_CONTAINER - + result = self.builder._transformers_model_builder_deploy_wrapper() - + mock_deploy.assert_called_once() - - @patch.object(MockModelBuilderServers, '_deploy_core_endpoint') + + @patch.object(MockModelBuilderServers, "_deploy_core_endpoint") def test_transformers_deploy_sagemaker_endpoint(self, mock_deploy): """Test Transformers deploy wrapper for SAGEMAKER_ENDPOINT mode.""" mock_deploy.return_value = Mock() self.builder.mode = Mode.SAGEMAKER_ENDPOINT - + result = self.builder._transformers_model_builder_deploy_wrapper() - + mock_deploy.assert_called_once() - - @patch.object(MockModelBuilderServers, '_deploy_core_endpoint') + + @patch.object(MockModelBuilderServers, "_deploy_core_endpoint") def test_deploy_wrapper_removes_mode_and_role(self, mock_deploy): """Test deploy wrapper removes mode and role from kwargs.""" mock_deploy.return_value = Mock() self.builder.mode = Mode.SAGEMAKER_ENDPOINT - + result = self.builder._djl_model_builder_deploy_wrapper( - mode=Mode.LOCAL_CONTAINER, - role="arn:aws:iam::123456789012:role/test" + mode=Mode.LOCAL_CONTAINER, role="arn:aws:iam::123456789012:role/test" ) - + call_kwargs = mock_deploy.call_args[1] self.assertNotIn("mode", call_kwargs) self.assertNotIn("role", call_kwargs) @@ -1047,13 +1148,13 @@ def test_deploy_wrapper_removes_mode_and_role(self, mock_deploy): class TestJumpStartBuilders(unittest.TestCase): """Test JumpStart-specific builder 
methods.""" - + def setUp(self): self.builder = MockModelBuilderServers() - - @patch('sagemaker.serve.model_builder_servers.prepare_djl_js_resources') - @patch.object(MockModelBuilderServers, '_prepare_for_mode') - @patch.object(MockModelBuilderServers, '_create_model') + + @patch("sagemaker.serve.model_builder_servers.prepare_djl_js_resources") + @patch.object(MockModelBuilderServers, "_prepare_for_mode") + @patch.object(MockModelBuilderServers, "_create_model") def test_build_for_djl_jumpstart_local(self, mock_create, mock_prepare, mock_djl_res): """Test _build_for_djl_jumpstart for local mode.""" mock_init_kwargs = Mock() @@ -1063,15 +1164,15 @@ def test_build_for_djl_jumpstart_local(self, mock_create, mock_prepare, mock_djl self.builder.mode = Mode.LOCAL_CONTAINER self.builder.model = "jumpstart-model-id" self.builder.s3_model_data_url = "s3://bucket/model.tar.gz" - + result = self.builder._build_for_djl_jumpstart(mock_init_kwargs) - + self.assertEqual(self.builder.model_server, ModelServer.DJL_SERVING) self.assertTrue(self.builder.prepared_for_djl) mock_djl_res.assert_called_once() mock_create.assert_called_once() - - @patch.object(MockModelBuilderServers, '_create_model') + + @patch.object(MockModelBuilderServers, "_create_model") def test_build_for_djl_jumpstart_sagemaker(self, mock_create): """Test _build_for_djl_jumpstart for SAGEMAKER_ENDPOINT mode.""" mock_init_kwargs = Mock() @@ -1079,16 +1180,16 @@ def test_build_for_djl_jumpstart_sagemaker(self, mock_create): mock_create.return_value = Mock() self.builder.mode = Mode.SAGEMAKER_ENDPOINT self.builder.model = "jumpstart-model-id" - + result = self.builder._build_for_djl_jumpstart(mock_init_kwargs) - + self.assertEqual(self.builder.s3_upload_path, "s3://bucket/model.tar.gz") self.assertTrue(self.builder.prepared_for_djl) mock_create.assert_called_once() - - @patch('sagemaker.serve.model_builder_servers.prepare_tgi_js_resources') - @patch.object(MockModelBuilderServers, '_prepare_for_mode') - @patch.object(MockModelBuilderServers, '_create_model') + + @patch("sagemaker.serve.model_builder_servers.prepare_tgi_js_resources") + @patch.object(MockModelBuilderServers, "_prepare_for_mode") + @patch.object(MockModelBuilderServers, "_create_model") def test_build_for_tgi_jumpstart_local(self, mock_create, mock_prepare, mock_tgi_res): """Test _build_for_tgi_jumpstart for local mode.""" mock_init_kwargs = Mock() @@ -1098,17 +1199,17 @@ def test_build_for_tgi_jumpstart_local(self, mock_create, mock_prepare, mock_tgi self.builder.mode = Mode.LOCAL_CONTAINER self.builder.model = "jumpstart-model-id" self.builder.s3_model_data_url = "s3://bucket/model.tar.gz" - + result = self.builder._build_for_tgi_jumpstart(mock_init_kwargs) - + self.assertEqual(self.builder.model_server, ModelServer.TGI) self.assertTrue(self.builder.prepared_for_tgi) mock_tgi_res.assert_called_once() mock_create.assert_called_once() - - @patch('sagemaker.serve.model_builder_servers.prepare_mms_js_resources') - @patch.object(MockModelBuilderServers, '_prepare_for_mode') - @patch.object(MockModelBuilderServers, '_create_model') + + @patch("sagemaker.serve.model_builder_servers.prepare_mms_js_resources") + @patch.object(MockModelBuilderServers, "_prepare_for_mode") + @patch.object(MockModelBuilderServers, "_create_model") def test_build_for_mms_jumpstart_local(self, mock_create, mock_prepare, mock_mms_res): """Test _build_for_mms_jumpstart for local mode.""" mock_init_kwargs = Mock() @@ -1118,9 +1219,9 @@ def test_build_for_mms_jumpstart_local(self, mock_create, mock_prepare, 
mock_mms self.builder.mode = Mode.LOCAL_CONTAINER self.builder.model = "jumpstart-model-id" self.builder.s3_model_data_url = "s3://bucket/model.tar.gz" - + result = self.builder._build_for_mms_jumpstart(mock_init_kwargs) - + self.assertEqual(self.builder.model_server, ModelServer.MMS) self.assertTrue(self.builder.prepared_for_mms) mock_mms_res.assert_called_once() diff --git a/sagemaker-serve/tests/unit/test_model_builder_utils_triton.py b/sagemaker-serve/tests/unit/test_model_builder_utils_triton.py index bb0d1d874c..85672a8d50 100644 --- a/sagemaker-serve/tests/unit/test_model_builder_utils_triton.py +++ b/sagemaker-serve/tests/unit/test_model_builder_utils_triton.py @@ -21,20 +21,21 @@ def test_triton_serializer_init(self): """Test TritonSerializer initialization.""" mock_serializer = Mock() serializer = TritonSerializer(mock_serializer, "FP32") - + self.assertEqual(serializer.dtype, "FP32") self.assertEqual(serializer.input_serializer, mock_serializer) def test_triton_serializer_serialize(self): """Test TritonSerializer serialize method.""" import numpy as np + mock_serializer = Mock() mock_array = np.array([[1, 2, 3]]) mock_serializer.serialize.return_value = mock_array - + serializer = TritonSerializer(mock_serializer, "FP32") result = serializer.serialize(mock_array) - + self.assertIsNotNone(result) @@ -45,8 +46,8 @@ def test_validate_for_triton_missing_tritonclient(self): """Test validation fails without tritonclient - skipped as tritonclient is installed.""" pass - @patch('importlib.util.find_spec') - @patch.object(_ModelBuilderUtils, '_has_nvidia_gpu') + @patch("importlib.util.find_spec") + @patch.object(_ModelBuilderUtils, "_has_nvidia_gpu") def test_validate_for_triton_no_gpu_local(self, mock_has_gpu, mock_find_spec): """Test validation fails for GPU mode without GPU.""" utils = _ModelBuilderUtils() @@ -56,23 +57,23 @@ def test_validate_for_triton_no_gpu_local(self, mock_has_gpu, mock_find_spec): utils.schema_builder = Mock() utils.schema_builder._update_serializer_deserializer_for_triton = Mock() utils.schema_builder._detect_dtype_for_triton = Mock() - + mock_find_spec.return_value = Mock() mock_has_gpu.return_value = False - + with self.assertRaises(ValueError): utils._validate_for_triton() - @patch('importlib.util.find_spec') + @patch("importlib.util.find_spec") def test_validate_for_triton_unsupported_mode(self, mock_find_spec): """Test validation fails for unsupported mode.""" utils = _ModelBuilderUtils() utils.mode = "UNSUPPORTED_MODE" utils.model_path = "/tmp/model" utils.schema_builder = Mock() - + mock_find_spec.return_value = Mock() - + with self.assertRaises(ValueError): utils._validate_for_triton() @@ -80,51 +81,51 @@ def test_validate_for_triton_unsupported_mode(self, mock_find_spec): class TestPrepareForTriton(unittest.TestCase): """Test _prepare_for_triton method.""" - @patch('shutil.copy2') - @patch.object(_ModelBuilderUtils, '_export_pytorch_to_onnx') + @patch("shutil.copy2") + @patch.object(_ModelBuilderUtils, "_export_pytorch_to_onnx") def test_prepare_for_triton_pytorch(self, mock_export, mock_copy): """Test preparing PyTorch model for Triton.""" utils = _ModelBuilderUtils() utils.framework = Framework.PYTORCH utils.model = Mock() utils.schema_builder = Mock() - + with tempfile.TemporaryDirectory() as tmpdir: utils.model_path = tmpdir utils._prepare_for_triton() - + mock_export.assert_called_once() - @patch('shutil.copy2') - @patch.object(_ModelBuilderUtils, '_export_tf_to_onnx') + @patch("shutil.copy2") + @patch.object(_ModelBuilderUtils, 
"_export_tf_to_onnx") def test_prepare_for_triton_tensorflow(self, mock_export, mock_copy): """Test preparing TensorFlow model for Triton.""" utils = _ModelBuilderUtils() utils.framework = Framework.TENSORFLOW utils.model = Mock() utils.schema_builder = Mock() - + with tempfile.TemporaryDirectory() as tmpdir: utils.model_path = tmpdir utils._prepare_for_triton() - + mock_export.assert_called_once() - @patch('shutil.copy2') - @patch.object(_ModelBuilderUtils, '_generate_config_pbtxt') - @patch.object(_ModelBuilderUtils, '_pack_conda_env') - @patch.object(_ModelBuilderUtils, '_hmac_signing') + @patch("shutil.copy2") + @patch.object(_ModelBuilderUtils, "_generate_config_pbtxt") + @patch.object(_ModelBuilderUtils, "_pack_conda_env") + @patch.object(_ModelBuilderUtils, "_compute_integrity_hash") def test_prepare_for_triton_inference_spec(self, mock_hmac, mock_pack, mock_config, mock_copy): """Test preparing inference spec for Triton.""" utils = _ModelBuilderUtils() utils.inference_spec = Mock() utils.model = None utils.schema_builder = Mock() - + with tempfile.TemporaryDirectory() as tmpdir: utils.model_path = tmpdir utils._prepare_for_triton() - + mock_config.assert_called_once() mock_pack.assert_called_once() mock_hmac.assert_called_once() @@ -133,26 +134,27 @@ def test_prepare_for_triton_inference_spec(self, mock_hmac, mock_pack, mock_conf class TestExportPytorchToOnnx(unittest.TestCase): """Test _export_pytorch_to_onnx method.""" - @patch('torch.onnx.export') + @patch("torch.onnx.export") def test_export_pytorch_to_onnx_success(self, mock_export): """Test successful PyTorch to ONNX export.""" try: import ml_dtypes + # Skip test if ml_dtypes doesn't have required attribute - if not hasattr(ml_dtypes, 'float4_e2m1fn'): + if not hasattr(ml_dtypes, "float4_e2m1fn"): self.skipTest("ml_dtypes version incompatible with current numpy/onnx") except ImportError: pass - + utils = _ModelBuilderUtils() mock_model = Mock() mock_schema = Mock() mock_schema.sample_input = Mock() - + with tempfile.TemporaryDirectory() as tmpdir: export_path = Path(tmpdir) utils._export_pytorch_to_onnx(mock_model, export_path, mock_schema) - + mock_export.assert_called_once() def test_export_pytorch_to_onnx_no_torch(self): @@ -167,7 +169,7 @@ class TestExportTFToOnnx(unittest.TestCase): def test_export_tf_to_onnx_no_tf2onnx(self): """Test TensorFlow export without tf2onnx installed.""" utils = _ModelBuilderUtils() - + # tf2onnx not installed in test environment with tempfile.TemporaryDirectory() as tmpdir: with self.assertRaises(ImportError): @@ -188,11 +190,11 @@ def test_generate_config_pbtxt_cpu(self): utils.schema_builder._sample_output_ndarray.shape = [1, 5] utils.schema_builder._input_triton_dtype = "FP32" utils.schema_builder._output_triton_dtype = "FP32" - + with tempfile.TemporaryDirectory() as tmpdir: pkl_path = Path(tmpdir) utils._generate_config_pbtxt(pkl_path) - + config_path = pkl_path / "config.pbtxt" self.assertTrue(config_path.exists()) content = config_path.read_text() @@ -209,11 +211,11 @@ def test_generate_config_pbtxt_gpu(self): utils.schema_builder._sample_output_ndarray.shape = [1, 5] utils.schema_builder._input_triton_dtype = "FP32" utils.schema_builder._output_triton_dtype = "FP32" - + with tempfile.TemporaryDirectory() as tmpdir: pkl_path = Path(tmpdir) utils._generate_config_pbtxt(pkl_path) - + config_path = pkl_path / "config.pbtxt" self.assertTrue(config_path.exists()) content = config_path.read_text() @@ -226,8 +228,8 @@ class TestPackCondaEnv(unittest.TestCase): def 
test_pack_conda_env_no_conda_pack(self): """Test packing conda env without conda_pack.""" utils = _ModelBuilderUtils() - - with patch('importlib.util.find_spec', return_value=None): + + with patch("importlib.util.find_spec", return_value=None): with tempfile.TemporaryDirectory() as tmpdir: with self.assertRaises(ImportError): utils._pack_conda_env(Path(tmpdir)) @@ -235,7 +237,7 @@ def test_pack_conda_env_no_conda_pack(self): def test_pack_conda_env_no_conda_pack_real(self): """Test packing conda env without conda_pack - real check.""" utils = _ModelBuilderUtils() - + with tempfile.TemporaryDirectory() as tmpdir: with self.assertRaises(ImportError): utils._pack_conda_env(Path(tmpdir)) @@ -249,37 +251,36 @@ def test_save_inference_spec(self): utils = _ModelBuilderUtils() utils.inference_spec = Mock() utils.schema_builder = Mock() - + with tempfile.TemporaryDirectory() as tmpdir: utils.model_path = tmpdir pkl_path = os.path.join(tmpdir, "model_repository", "model") os.makedirs(pkl_path, exist_ok=True) - + utils._save_inference_spec() - + # Check that serve.pkl was created self.assertTrue(os.path.exists(os.path.join(pkl_path, "serve.pkl"))) class TestHMACSignin(unittest.TestCase): - """Test _hmac_signing method.""" + """Test _compute_integrity_hash method.""" - def test_hmac_signing(self): - """Test HMAC signing.""" + def test_compute_integrity_hash(self): + """Test SHA-256 integrity hash computation.""" utils = _ModelBuilderUtils() - + with tempfile.TemporaryDirectory() as tmpdir: utils.model_path = tmpdir pkl_path = Path(tmpdir) / "model_repository" / "model" pkl_path.mkdir(parents=True) - + # Create dummy serve.pkl (pkl_path / "serve.pkl").write_bytes(b"dummy content") - - utils._hmac_signing() - - # Secret key is generated, not mocked - self.assertIsNotNone(utils.secret_key) + + utils._compute_integrity_hash() + + # metadata.json should be created with the SHA-256 hash self.assertTrue((pkl_path / "metadata.json").exists()) @@ -291,9 +292,9 @@ def test_auto_detect_image_skip_if_provided(self): utils = _ModelBuilderUtils() utils.image_uri = "custom-triton-image" utils.sagemaker_session = Mock() - + utils._auto_detect_image_for_triton() - + self.assertEqual(utils.image_uri, "custom-triton-image") def test_auto_detect_image_cpu_instance(self): @@ -306,9 +307,9 @@ def test_auto_detect_image_cpu_instance(self): utils.inference_spec = None utils.framework = "pytorch" utils.version = "1.13" - + utils._auto_detect_image_for_triton() - + self.assertIsNotNone(utils.image_uri) self.assertIn("-cpu", utils.image_uri) @@ -322,9 +323,9 @@ def test_auto_detect_image_gpu_instance(self): utils.inference_spec = None utils.framework = "pytorch" utils.version = "1.13" - + utils._auto_detect_image_for_triton() - + self.assertIsNotNone(utils.image_uri) self.assertNotIn("-cpu", utils.image_uri) @@ -335,7 +336,7 @@ def test_auto_detect_image_unsupported_region(self): utils.instance_type = "ml.g5.xlarge" utils.sagemaker_session = Mock() utils.sagemaker_session.boto_region_name = "unsupported-region" - + with self.assertRaises(ValueError): utils._auto_detect_image_for_triton() @@ -349,7 +350,7 @@ def test_validate_djl_valid_data(self): utils.schema_builder = Mock() utils.schema_builder.sample_input = {"inputs": "test", "parameters": {}} utils.schema_builder.sample_output = [{"generated_text": "output"}] - + # Should not raise utils._validate_djl_serving_sample_data() @@ -359,7 +360,7 @@ def test_validate_djl_invalid_input(self): utils.schema_builder = Mock() utils.schema_builder.sample_input = {"wrong_key": "test"} 
utils.schema_builder.sample_output = [{"generated_text": "output"}] - + with self.assertRaises(ValueError): utils._validate_djl_serving_sample_data() @@ -369,7 +370,7 @@ def test_validate_djl_invalid_output(self): utils.schema_builder = Mock() utils.schema_builder.sample_input = {"inputs": "test", "parameters": {}} utils.schema_builder.sample_output = [{"wrong_key": "output"}] - + with self.assertRaises(ValueError): utils._validate_djl_serving_sample_data() @@ -383,7 +384,7 @@ def test_validate_tgi_valid_data(self): utils.schema_builder = Mock() utils.schema_builder.sample_input = {"inputs": "test", "parameters": {}} utils.schema_builder.sample_output = [{"generated_text": "output"}] - + # Should not raise utils._validate_tgi_serving_sample_data() @@ -393,7 +394,7 @@ def test_validate_tgi_invalid_input(self): utils.schema_builder = Mock() utils.schema_builder.sample_input = "invalid" utils.schema_builder.sample_output = [{"generated_text": "output"}] - + with self.assertRaises(ValueError): utils._validate_tgi_serving_sample_data() @@ -401,15 +402,15 @@ def test_validate_tgi_invalid_input(self): class TestCreateCondaEnv(unittest.TestCase): """Test _create_conda_env method.""" - @patch('sagemaker.serve.builder.requirements_manager.RequirementsManager') + @patch("sagemaker.serve.builder.requirements_manager.RequirementsManager") def test_create_conda_env_success(self, mock_req_manager): """Test successful conda env creation.""" utils = _ModelBuilderUtils() mock_manager = Mock() mock_req_manager.return_value = mock_manager - + utils._create_conda_env() - + # Should not raise diff --git a/sagemaker-serve/tests/unit/validations/test_check_integrity.py b/sagemaker-serve/tests/unit/validations/test_check_integrity.py index 11e66eb716..cc05c460bb 100644 --- a/sagemaker-serve/tests/unit/validations/test_check_integrity.py +++ b/sagemaker-serve/tests/unit/validations/test_check_integrity.py @@ -1,39 +1,27 @@ import unittest -import tempfile from pathlib import Path from unittest.mock import patch, mock_open -from sagemaker.serve.validations.check_integrity import ( - generate_secret_key, - compute_hash, - perform_integrity_check -) +from sagemaker.serve.validations.check_integrity import compute_hash, perform_integrity_check class TestCheckIntegrity(unittest.TestCase): - def test_generate_secret_key(self): - key = generate_secret_key() - self.assertIsInstance(key, str) - self.assertEqual(len(key), 64) - - def test_generate_secret_key_custom_bytes(self): - key = generate_secret_key(nbytes=16) - self.assertEqual(len(key), 32) - def test_compute_hash(self): buffer = b"test data" - secret_key = "test_secret" - hash_value = compute_hash(buffer, secret_key) + hash_value = compute_hash(buffer) self.assertIsInstance(hash_value, str) self.assertEqual(len(hash_value), 64) def test_compute_hash_consistency(self): buffer = b"test data" - secret_key = "test_secret" - hash1 = compute_hash(buffer, secret_key) - hash2 = compute_hash(buffer, secret_key) + hash1 = compute_hash(buffer) + hash2 = compute_hash(buffer) self.assertEqual(hash1, hash2) - @patch.dict("os.environ", {"SAGEMAKER_SERVE_SECRET_KEY": "test_key"}) + def test_compute_hash_different_data(self): + hash1 = compute_hash(b"data1") + hash2 = compute_hash(b"data2") + self.assertNotEqual(hash1, hash2) + @patch("pathlib.Path.exists") @patch("builtins.open", new_callable=mock_open, read_data=b'{"sha256_hash": "test_hash"}') @patch("sagemaker.serve.validations.check_integrity._MetaData.from_json") @@ -41,10 +29,14 @@ def 
test_perform_integrity_check_failure(self, mock_metadata, mock_file, mock_ex
         mock_exists.return_value = True
         mock_meta = type("obj", (object,), {"sha256_hash": "wrong_hash"})()
         mock_metadata.return_value = mock_meta
-
+
         with self.assertRaises(ValueError):
             perform_integrity_check(b"test", Path("/tmp/metadata.json"))
 
+    def test_perform_integrity_check_missing_metadata(self):
+        with self.assertRaisesRegex(ValueError, "Path to metadata.json does not exist"):
+            perform_integrity_check(b"test", Path("/nonexistent/metadata.json"))
+
 
 if __name__ == "__main__":
     unittest.main()
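
Note on the new integrity-check contract (reviewer sketch, not part of the patch): the updated tests assume compute_hash is now a plain SHA-256 digest over the serialized payload, with no HMAC secret key, and that perform_integrity_check recomputes that digest and compares it against the sha256_hash field persisted in metadata.json, raising ValueError on a mismatch or when the metadata path is missing. A minimal sketch of that contract, assuming these exact signatures (the real code parses metadata via _MetaData.from_json, simplified here to json.loads, and the mismatch message is illustrative):

    import hashlib
    import json
    from pathlib import Path

    def compute_hash(buffer: bytes) -> str:
        # Plain SHA-256 over the payload: deterministic, 64 hex chars, no secret key.
        return hashlib.sha256(buffer).hexdigest()

    def perform_integrity_check(buffer: bytes, metadata_path: Path) -> None:
        # Fail fast when the metadata file was never written (message taken
        # from the missing-metadata test above).
        if not metadata_path.exists():
            raise ValueError("Path to metadata.json does not exist")
        expected = json.loads(metadata_path.read_text())["sha256_hash"]
        if compute_hash(buffer) != expected:
            # Illustrative message; the exact wording is not asserted by the tests.
            raise ValueError("Integrity check for the serialized artifact failed.")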
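
Relatedly, test_build_clean_empty_secret_key above implies that builders now scrub an empty SAGEMAKER_SERVE_SECRET_KEY entry from env_vars rather than shipping it to the endpoint. A hypothetical helper capturing that behavior (the actual cleanup lives inside _build_for_transformers; this name and signature are illustrative):

    def _drop_empty_secret_key(env_vars: dict) -> None:
        # Remove the legacy secret-key entry when it carries no value;
        # integrity now relies on the SHA-256 hash in metadata.json instead.
        if env_vars.get("SAGEMAKER_SERVE_SECRET_KEY") == "":
            del env_vars["SAGEMAKER_SERVE_SECRET_KEY"]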