diff --git a/v3-examples/ml-ops-examples/v3-mlflow-train-inference-e2e-example.ipynb b/v3-examples/ml-ops-examples/v3-mlflow-train-inference-e2e-example.ipynb
index 80435c9325..e8457f4561 100644
--- a/v3-examples/ml-ops-examples/v3-mlflow-train-inference-e2e-example.ipynb
+++ b/v3-examples/ml-ops-examples/v3-mlflow-train-inference-e2e-example.ipynb
@@ -35,7 +35,8 @@
    "outputs": [],
    "source": [
     "# Install fix for MLflow path resolution issues\n",
-    "%pip install mlflow==3.4.0"
+    "# Using minimum version constraint so the notebook stays compatible with future releases\n",
+    "%pip install 'mlflow>=3.4.0'"
    ]
   },
   {
@@ -60,6 +61,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
+    "import json\n",
     "import uuid\n",
     "from sagemaker.core import image_uris\n",
     "from sagemaker.core.helper.session_helper import Session\n",
@@ -71,7 +73,8 @@
     "MLFLOW_TRACKING_ARN = \"XXXXX\"\n",
     "\n",
     "# AWS Configuration\n",
-    "AWS_REGION = Session.boto_region_name\n",
+    "sagemaker_session = Session()\n",
+    "AWS_REGION = sagemaker_session.boto_region_name\n",
     "\n",
     "# Get PyTorch training image dynamically\n",
     "PYTORCH_TRAINING_IMAGE = image_uris.retrieve(\n",
@@ -304,6 +307,7 @@
     "        requirements=\"requirements.txt\",\n",
     "    ),\n",
     "    base_job_name=training_job_name,\n",
+    "    sagemaker_session=sagemaker_session,\n",
     ")\n",
     "\n",
     "# Start training job\n",
@@ -333,9 +337,16 @@
     "from mlflow import MlflowClient\n",
     "\n",
     "client = MlflowClient()\n",
-    "registered_model = client.get_registered_model(name=MLFLOW_REGISTERED_MODEL_NAME)\n",
     "\n",
-    "latest_version = registered_model.latest_versions[0]\n",
+    "# Use search_model_versions (compatible with MLflow 3.x)\n",
+    "# Note: order_by field name may vary across MLflow versions;\n",
+    "# 'creation_timestamp' is broadly supported.\n",
+    "model_versions = client.search_model_versions(\n",
+    "    filter_string=f\"name='{MLFLOW_REGISTERED_MODEL_NAME}'\",\n",
+    "    order_by=['creation_timestamp DESC'],\n",
+    "    max_results=1\n",
+    ")\n",
+    "latest_version = model_versions[0]\n",
     "model_version = latest_version.version\n",
     "model_source = latest_version.source\n",
     "\n",
@@ -366,54 +377,23 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "import json\n",
-    "import torch\n",
-    "from sagemaker.serve.marshalling.custom_payload_translator import CustomPayloadTranslator\n",
     "from sagemaker.serve.builder.schema_builder import SchemaBuilder\n",
     "\n",
     "# =============================================================================\n",
-    "# Custom translators for PyTorch tensor conversion\n",
-    "# \n",
-    "# PyTorch models expect tensors, but SageMaker endpoints communicate via JSON.\n",
-    "# These translators handle the conversion between JSON payloads and PyTorch tensors.\n",
+    "# Schema Builder for MLflow Model\n",
+    "#\n",
+    "# When deploying from MLflow Model Registry, the MLflow pyfunc wrapper handles\n",
+    "# serialization/deserialization automatically. We only need to provide sample\n",
+    "# input/output for schema inference - no custom translators needed.\n",
     "# =============================================================================\n",
     "\n",
-    "class PyTorchInputTranslator(CustomPayloadTranslator):\n",
-    "    \"\"\"Handles input serialization/deserialization for PyTorch models.\"\"\"\n",
-    "    def __init__(self):\n",
-    "        super().__init__(content_type='application/json', accept_type='application/json')\n",
-    "    \n",
-    "    def serialize_payload_to_bytes(self, payload: object) -> bytes:\n",
-    "        if isinstance(payload, torch.Tensor):\n",
-    "            return json.dumps(payload.tolist()).encode('utf-8')\n",
-    "        return json.dumps(payload).encode('utf-8')\n",
-    "    \n",
-    "    def deserialize_payload_from_stream(self, stream) -> object:\n",
-    "        data = json.load(stream)\n",
-    "        return torch.tensor(data, dtype=torch.float32)\n",
-    "\n",
-    "class PyTorchOutputTranslator(CustomPayloadTranslator):\n",
-    "    \"\"\"Handles output serialization/deserialization for PyTorch models.\"\"\"\n",
-    "    def __init__(self):\n",
-    "        super().__init__(content_type='application/json', accept_type='application/json')\n",
-    "    \n",
-    "    def serialize_payload_to_bytes(self, payload: object) -> bytes:\n",
-    "        if isinstance(payload, torch.Tensor):\n",
-    "            return json.dumps(payload.tolist()).encode('utf-8')\n",
-    "        return json.dumps(payload).encode('utf-8')\n",
-    "    \n",
-    "    def deserialize_payload_from_stream(self, stream) -> object:\n",
-    "        return json.load(stream)\n",
-    "\n",
     "# Sample input/output for schema inference\n",
     "sample_input = [[0.1, 0.2, 0.3, 0.4]]\n",
     "sample_output = [[0.8, 0.2]]\n",
     "\n",
     "schema_builder = SchemaBuilder(\n",
     "    sample_input=sample_input,\n",
-    "    sample_output=sample_output,\n",
-    "    input_translator=PyTorchInputTranslator(),\n",
-    "    output_translator=PyTorchOutputTranslator()\n",
+    "    sample_output=sample_output\n",
     ")"
    ]
   },
   {
@@ -434,7 +414,7 @@
     "        \"MLFLOW_MODEL_PATH\": mlflow_model_path,\n",
     "        \"MLFLOW_TRACKING_ARN\": MLFLOW_TRACKING_ARN\n",
     "    },\n",
-    "    dependencies={\"auto\": False, \"custom\": [\"mlflow==3.4.0\", \"sagemaker==3.3.1\", \"numpy==2.4.1\", \"cloudpickle==3.1.2\"]},\n",
+    "    dependencies={\"auto\": False, \"custom\": [\"mlflow>=3.4.0\", \"sagemaker>=3.3.1\", \"numpy>=2.4.1\", \"cloudpickle>=3.1.2\"]},\n",
     ")\n",
     "\n",
     "print(f\"ModelBuilder configured with MLflow model: {mlflow_model_path}\")"
    ]
   },
   {
@@ -481,19 +461,22 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "import boto3\n",
-    "\n",
     "# Test with JSON input\n",
     "test_data = [[0.1, 0.2, 0.3, 0.4]]\n",
     "\n",
-    "runtime_client = boto3.client('sagemaker-runtime')\n",
-    "response = runtime_client.invoke_endpoint(\n",
-    "    EndpointName=core_endpoint.endpoint_name,\n",
-    "    Body=json.dumps(test_data),\n",
-    "    ContentType='application/json'\n",
+    "result = core_endpoint.invoke(\n",
+    "    body=json.dumps(test_data).encode('utf-8'),\n",
+    "    content_type='application/json'\n",
     ")\n",
     "\n",
-    "prediction = json.loads(response['Body'].read().decode('utf-8'))\n",
+    "# The invoke() response body may be a streaming object or bytes;\n",
+    "# handle both cases for robustness.\n",
+    "response_body = result.body\n",
+    "if hasattr(response_body, 'read'):\n",
+    "    response_body = response_body.read()\n",
+    "if isinstance(response_body, bytes):\n",
+    "    response_body = response_body.decode('utf-8')\n",
+    "prediction = json.loads(response_body)\n",
     "print(f\"Input: {test_data}\")\n",
     "print(f\"Prediction: {prediction}\")"
    ]
   },
@@ -551,7 +534,9 @@
     "- `ModelBuilder` with `MLFLOW_MODEL_PATH` - deploy from registry\n",
     "\n",
     "Key patterns:\n",
-    "- Custom `PayloadTranslator` classes for PyTorch tensor serialization\n"
+    "- MLflow pyfunc handles model serialization automatically\n",
+    "- `SchemaBuilder` with sample input/output for schema inference\n",
+    "- `core_endpoint.invoke()` for V3-style endpoint invocation\n"
    ]
   }
  ],