Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@
"outputs": [],
"source": [
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The %pip install cell only installs mlflow>=3.4.0, but the training code's requirements.txt (referenced in Step 4) presumably still pins mlflow==3.4.0. The PR description mentions this inconsistency but the diff only shows changes to the notebook cells and the dependencies dict in ModelBuilder. If requirements.txt is a separate file in the repo, it should also be updated to use >= constraints for consistency. Could you confirm whether requirements.txt is part of this PR or needs a separate change?

"# Install fix for MLflow path resolution issues\n",
"%pip install mlflow==3.4.0"
"# Using minimum version constraint so the notebook stays compatible with future releases\n",
"%pip install 'mlflow>=3.4.0'"
]
},
{
Expand All @@ -60,6 +61,7 @@
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"import uuid\n",
"from sagemaker.core import image_uris\n",
"from sagemaker.core.helper.session_helper import Session\n",
Expand All @@ -71,7 +73,8 @@
"MLFLOW_TRACKING_ARN = \"XXXXX\"\n",
"\n",
"# AWS Configuration\n",
"AWS_REGION = Session.boto_region_name\n",
"sagemaker_session = Session()\n",
"AWS_REGION = sagemaker_session.boto_region_name\n",
"\n",
"# Get PyTorch training image dynamically\n",
"PYTORCH_TRAINING_IMAGE = image_uris.retrieve(\n",
Expand Down Expand Up @@ -304,6 +307,7 @@
" requirements=\"requirements.txt\",\n",
" ),\n",
" base_job_name=training_job_name,\n",
" sagemaker_session=sagemaker_session,\n",
")\n",
"\n",
"# Start training job\n",
Expand Down Expand Up @@ -333,9 +337,16 @@
"from mlflow import MlflowClient\n",
"\n",
"client = MlflowClient()\n",
"registered_model = client.get_registered_model(name=MLFLOW_REGISTERED_MODEL_NAME)\n",
"\n",
"latest_version = registered_model.latest_versions[0]\n",
"# Use search_model_versions (compatible with MLflow 3.x)\n",
"# Note: order_by field name may vary across MLflow versions;\n",
"# 'creation_timestamp' is broadly supported.\n",
"model_versions = client.search_model_versions(\n",
" filter_string=f\"name='{MLFLOW_REGISTERED_MODEL_NAME}'\",\n",
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor: the `max_results=1` line is missing a trailing comma before the closing paren, and the list `['creation_timestamp DESC']` uses single quotes inside a double-quoted JSON string — this is fine for Python but worth noting for consistency. More importantly, consider adding a guard for the case where `model_versions` is empty (no versions registered yet), e.g.:

if not model_versions:
    raise RuntimeError(f"No model versions found for '{MLFLOW_REGISTERED_MODEL_NAME}'")

This would give users a clear error message instead of an IndexError.

" order_by=['creation_timestamp DESC'],\n",
" max_results=1\n",
")\n",
"latest_version = model_versions[0]\n",
"model_version = latest_version.version\n",
"model_source = latest_version.source\n",
"\n",
Expand Down Expand Up @@ -366,54 +377,23 @@
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"import torch\n",
"from sagemaker.serve.marshalling.custom_payload_translator import CustomPayloadTranslator\n",
"from sagemaker.serve.builder.schema_builder import SchemaBuilder\n",
"\n",
"# =============================================================================\n",
"# Custom translators for PyTorch tensor conversion\n",
"# \n",
"# PyTorch models expect tensors, but SageMaker endpoints communicate via JSON.\n",
"# These translators handle the conversion between JSON payloads and PyTorch tensors.\n",
"# Schema Builder for MLflow Model\n",
"#\n",
"# When deploying from MLflow Model Registry, the MLflow pyfunc wrapper handles\n",
"# serialization/deserialization automatically. We only need to provide sample\n",
"# input/output for schema inference - no custom translators needed.\n",
"# =============================================================================\n",
"\n",
"class PyTorchInputTranslator(CustomPayloadTranslator):\n",
" \"\"\"Handles input serialization/deserialization for PyTorch models.\"\"\"\n",
" def __init__(self):\n",
" super().__init__(content_type='application/json', accept_type='application/json')\n",
" \n",
" def serialize_payload_to_bytes(self, payload: object) -> bytes:\n",
" if isinstance(payload, torch.Tensor):\n",
" return json.dumps(payload.tolist()).encode('utf-8')\n",
" return json.dumps(payload).encode('utf-8')\n",
" \n",
" def deserialize_payload_from_stream(self, stream) -> object:\n",
" data = json.load(stream)\n",
" return torch.tensor(data, dtype=torch.float32)\n",
"\n",
"class PyTorchOutputTranslator(CustomPayloadTranslator):\n",
" \"\"\"Handles output serialization/deserialization for PyTorch models.\"\"\"\n",
" def __init__(self):\n",
" super().__init__(content_type='application/json', accept_type='application/json')\n",
" \n",
" def serialize_payload_to_bytes(self, payload: object) -> bytes:\n",
" if isinstance(payload, torch.Tensor):\n",
" return json.dumps(payload.tolist()).encode('utf-8')\n",
" return json.dumps(payload).encode('utf-8')\n",
" \n",
" def deserialize_payload_from_stream(self, stream) -> object:\n",
" return json.load(stream)\n",
"\n",
"# Sample input/output for schema inference\n",
"sample_input = [[0.1, 0.2, 0.3, 0.4]]\n",
"sample_output = [[0.8, 0.2]]\n",
"\n",
"schema_builder = SchemaBuilder(\n",
" sample_input=sample_input,\n",
" sample_output=sample_output,\n",
" input_translator=PyTorchInputTranslator(),\n",
" output_translator=PyTorchOutputTranslator()\n",
" sample_output=sample_output\n",
")"
]
},
Expand All @@ -434,7 +414,7 @@
" \"MLFLOW_MODEL_PATH\": mlflow_model_path,\n",
" \"MLFLOW_TRACKING_ARN\": MLFLOW_TRACKING_ARN\n",
" },\n",
" dependencies={\"auto\": False, \"custom\": [\"mlflow==3.4.0\", \"sagemaker==3.3.1\", \"numpy==2.4.1\", \"cloudpickle==3.1.2\"]},\n",
" dependencies={\"auto\": False, \"custom\": [\"mlflow>=3.4.0\", \"sagemaker>=3.3.1\", \"numpy>=2.4.1\", \"cloudpickle>=3.1.2\"]},\n",
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using >= without an upper bound for deployment dependencies is risky — these packages are installed at inference time inside the container, and a future major version bump of mlflow, sagemaker, numpy, or cloudpickle could break the deployed endpoint silently. Consider using compatible-release constraints (e.g., mlflow>=3.4.0,<4, numpy>=2.4.1,<3, cloudpickle>=3.1.2,<4) to balance freshness with stability. This is especially important for sagemaker which had a major v2→v3 transition.

")\n",
"\n",
"print(f\"ModelBuilder configured with MLflow model: {mlflow_model_path}\")"
Expand Down Expand Up @@ -481,19 +461,22 @@
"metadata": {},
"outputs": [],
"source": [
"import boto3\n",
"\n",
"# Test with JSON input\n",
"test_data = [[0.1, 0.2, 0.3, 0.4]]\n",
"\n",
"runtime_client = boto3.client('sagemaker-runtime')\n",
"response = runtime_client.invoke_endpoint(\n",
" EndpointName=core_endpoint.endpoint_name,\n",
" Body=json.dumps(test_data),\n",
" ContentType='application/json'\n",
"result = core_endpoint.invoke(\n",
" body=json.dumps(test_data).encode('utf-8'),\n",
" content_type='application/json'\n",
")\n",
"\n",
"prediction = json.loads(response['Body'].read().decode('utf-8'))\n",
"# The invoke() response body may be a streaming object or bytes;\n",
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The response body handling logic is reasonable for robustness, but it would be good to add a brief comment or reference to the core_endpoint.invoke() return type from sagemaker-core so future maintainers know what to expect. Also, consider whether result.body could be str already — the current code handles bytes and stream but not str, though in practice it's likely always one of the first two.

"# handle both cases for robustness.\n",
"response_body = result.body\n",
"if hasattr(response_body, 'read'):\n",
" response_body = response_body.read()\n",
"if isinstance(response_body, bytes):\n",
" response_body = response_body.decode('utf-8')\n",
"prediction = json.loads(response_body)\n",
"print(f\"Input: {test_data}\")\n",
"print(f\"Prediction: {prediction}\")"
]
Expand Down Expand Up @@ -551,7 +534,9 @@
"- `ModelBuilder` with `MLFLOW_MODEL_PATH` - deploy from registry\n",
"\n",
"Key patterns:\n",
"- Custom `PayloadTranslator` classes for PyTorch tensor serialization\n"
"- MLflow pyfunc handles model serialization automatically\n",
"- `SchemaBuilder` with sample input/output for schema inference\n",
"- `core_endpoint.invoke()` for V3-style endpoint invocation\n"
]
}
],
Expand Down
Loading