-
Notifications
You must be signed in to change notification settings - Fork 1.2k
feature: MLFlow E2E Example Notebook (5513) #5701
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -35,7 +35,8 @@ | |
| "outputs": [], | ||
| "source": [ | ||
| "# Install fix for MLflow path resolution issues\n", | ||
aviruthen marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| "%pip install mlflow==3.4.0" | ||
| "# Using minimum version constraint so the notebook stays compatible with future releases\n", | ||
| "%pip install 'mlflow>=3.4.0'" | ||
| ] | ||
| }, | ||
| { | ||
|
|
@@ -60,6 +61,7 @@ | |
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "import json\n", | ||
aviruthen marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| "import uuid\n", | ||
| "from sagemaker.core import image_uris\n", | ||
| "from sagemaker.core.helper.session_helper import Session\n", | ||
|
|
@@ -71,7 +73,8 @@ | |
| "MLFLOW_TRACKING_ARN = \"XXXXX\"\n", | ||
| "\n", | ||
| "# AWS Configuration\n", | ||
| "AWS_REGION = Session.boto_region_name\n", | ||
| "sagemaker_session = Session()\n", | ||
| "AWS_REGION = sagemaker_session.boto_region_name\n", | ||
| "\n", | ||
| "# Get PyTorch training image dynamically\n", | ||
| "PYTORCH_TRAINING_IMAGE = image_uris.retrieve(\n", | ||
|
|
@@ -304,6 +307,7 @@ | |
| " requirements=\"requirements.txt\",\n", | ||
| " ),\n", | ||
| " base_job_name=training_job_name,\n", | ||
| " sagemaker_session=sagemaker_session,\n", | ||
| ")\n", | ||
| "\n", | ||
| "# Start training job\n", | ||
|
|
@@ -333,9 +337,16 @@ | |
| "from mlflow import MlflowClient\n", | ||
| "\n", | ||
| "client = MlflowClient()\n", | ||
| "registered_model = client.get_registered_model(name=MLFLOW_REGISTERED_MODEL_NAME)\n", | ||
| "\n", | ||
| "latest_version = registered_model.latest_versions[0]\n", | ||
| "# Use search_model_versions (compatible with MLflow 3.x)\n", | ||
| "# Note: order_by field name may vary across MLflow versions;\n", | ||
| "# 'creation_timestamp' is broadly supported.\n", | ||
| "model_versions = client.search_model_versions(\n", | ||
| " filter_string=f\"name='{MLFLOW_REGISTERED_MODEL_NAME}'\",\n", | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.

Minor: after the trailing closing paren of the `search_model_versions(...)` call, consider adding a guard:

```python
if not model_versions:
    raise RuntimeError(f"No model versions found for '{MLFLOW_REGISTERED_MODEL_NAME}'")
```

This would give users a clear error message instead of an `IndexError` when `model_versions[0]` is accessed on an empty result.
||
| " order_by=['creation_timestamp DESC'],\n", | ||
| " max_results=1\n", | ||
| ")\n", | ||
| "latest_version = model_versions[0]\n", | ||
| "model_version = latest_version.version\n", | ||
| "model_source = latest_version.source\n", | ||
| "\n", | ||
|
|
@@ -366,54 +377,23 @@ | |
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "import json\n", | ||
| "import torch\n", | ||
| "from sagemaker.serve.marshalling.custom_payload_translator import CustomPayloadTranslator\n", | ||
| "from sagemaker.serve.builder.schema_builder import SchemaBuilder\n", | ||
| "\n", | ||
| "# =============================================================================\n", | ||
| "# Custom translators for PyTorch tensor conversion\n", | ||
| "# \n", | ||
| "# PyTorch models expect tensors, but SageMaker endpoints communicate via JSON.\n", | ||
| "# These translators handle the conversion between JSON payloads and PyTorch tensors.\n", | ||
| "# Schema Builder for MLflow Model\n", | ||
| "#\n", | ||
| "# When deploying from MLflow Model Registry, the MLflow pyfunc wrapper handles\n", | ||
| "# serialization/deserialization automatically. We only need to provide sample\n", | ||
| "# input/output for schema inference - no custom translators needed.\n", | ||
| "# =============================================================================\n", | ||
| "\n", | ||
| "class PyTorchInputTranslator(CustomPayloadTranslator):\n", | ||
| " \"\"\"Handles input serialization/deserialization for PyTorch models.\"\"\"\n", | ||
| " def __init__(self):\n", | ||
| " super().__init__(content_type='application/json', accept_type='application/json')\n", | ||
| " \n", | ||
| " def serialize_payload_to_bytes(self, payload: object) -> bytes:\n", | ||
| " if isinstance(payload, torch.Tensor):\n", | ||
| " return json.dumps(payload.tolist()).encode('utf-8')\n", | ||
| " return json.dumps(payload).encode('utf-8')\n", | ||
| " \n", | ||
| " def deserialize_payload_from_stream(self, stream) -> object:\n", | ||
| " data = json.load(stream)\n", | ||
| " return torch.tensor(data, dtype=torch.float32)\n", | ||
| "\n", | ||
| "class PyTorchOutputTranslator(CustomPayloadTranslator):\n", | ||
| " \"\"\"Handles output serialization/deserialization for PyTorch models.\"\"\"\n", | ||
| " def __init__(self):\n", | ||
| " super().__init__(content_type='application/json', accept_type='application/json')\n", | ||
| " \n", | ||
| " def serialize_payload_to_bytes(self, payload: object) -> bytes:\n", | ||
| " if isinstance(payload, torch.Tensor):\n", | ||
| " return json.dumps(payload.tolist()).encode('utf-8')\n", | ||
| " return json.dumps(payload).encode('utf-8')\n", | ||
| " \n", | ||
| " def deserialize_payload_from_stream(self, stream) -> object:\n", | ||
| " return json.load(stream)\n", | ||
| "\n", | ||
| "# Sample input/output for schema inference\n", | ||
| "sample_input = [[0.1, 0.2, 0.3, 0.4]]\n", | ||
| "sample_output = [[0.8, 0.2]]\n", | ||
| "\n", | ||
| "schema_builder = SchemaBuilder(\n", | ||
| " sample_input=sample_input,\n", | ||
| " sample_output=sample_output,\n", | ||
| " input_translator=PyTorchInputTranslator(),\n", | ||
| " output_translator=PyTorchOutputTranslator()\n", | ||
| " sample_output=sample_output\n", | ||
| ")" | ||
| ] | ||
| }, | ||
|
|
@@ -434,7 +414,7 @@ | |
| " \"MLFLOW_MODEL_PATH\": mlflow_model_path,\n", | ||
| " \"MLFLOW_TRACKING_ARN\": MLFLOW_TRACKING_ARN\n", | ||
| " },\n", | ||
| " dependencies={\"auto\": False, \"custom\": [\"mlflow==3.4.0\", \"sagemaker==3.3.1\", \"numpy==2.4.1\", \"cloudpickle==3.1.2\"]},\n", | ||
| " dependencies={\"auto\": False, \"custom\": [\"mlflow>=3.4.0\", \"sagemaker>=3.3.1\", \"numpy>=2.4.1\", \"cloudpickle>=3.1.2\"]},\n", | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.

Using `>=` version constraints here (instead of `==` pins) keeps the `dependencies` dict consistent with the `%pip install` cell and compatible with future releases. [Note: the original comment was truncated in this capture; wording reconstructed from context.]
||
| ")\n", | ||
| "\n", | ||
| "print(f\"ModelBuilder configured with MLflow model: {mlflow_model_path}\")" | ||
|
|
@@ -481,19 +461,22 @@ | |
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "import boto3\n", | ||
| "\n", | ||
| "# Test with JSON input\n", | ||
| "test_data = [[0.1, 0.2, 0.3, 0.4]]\n", | ||
| "\n", | ||
| "runtime_client = boto3.client('sagemaker-runtime')\n", | ||
| "response = runtime_client.invoke_endpoint(\n", | ||
| " EndpointName=core_endpoint.endpoint_name,\n", | ||
| " Body=json.dumps(test_data),\n", | ||
| " ContentType='application/json'\n", | ||
| "result = core_endpoint.invoke(\n", | ||
| " body=json.dumps(test_data).encode('utf-8'),\n", | ||
| " content_type='application/json'\n", | ||
aviruthen marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| ")\n", | ||
| "\n", | ||
| "prediction = json.loads(response['Body'].read().decode('utf-8'))\n", | ||
| "# The invoke() response body may be a streaming object or bytes;\n", | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.

The response body handling logic is reasonable for robustness, but it would be good to add a brief comment or reference to the `invoke()` API documentation explaining when the returned body is a streaming object versus raw bytes. [Note: the original comment was truncated in this capture; ending reconstructed from context.]
||
| "# handle both cases for robustness.\n", | ||
| "response_body = result.body\n", | ||
| "if hasattr(response_body, 'read'):\n", | ||
| " response_body = response_body.read()\n", | ||
| "if isinstance(response_body, bytes):\n", | ||
| " response_body = response_body.decode('utf-8')\n", | ||
| "prediction = json.loads(response_body)\n", | ||
| "print(f\"Input: {test_data}\")\n", | ||
| "print(f\"Prediction: {prediction}\")" | ||
| ] | ||
|
|
@@ -551,7 +534,9 @@ | |
| "- `ModelBuilder` with `MLFLOW_MODEL_PATH` - deploy from registry\n", | ||
| "\n", | ||
| "Key patterns:\n", | ||
| "- Custom `PayloadTranslator` classes for PyTorch tensor serialization\n" | ||
| "- MLflow pyfunc handles model serialization automatically\n", | ||
| "- `SchemaBuilder` with sample input/output for schema inference\n", | ||
| "- `core_endpoint.invoke()` for V3-style endpoint invocation\n" | ||
| ] | ||
| } | ||
| ], | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The `%pip install` cell only installs `mlflow>=3.4.0`, but the training code's `requirements.txt` (referenced in Step 4) presumably still pins `mlflow==3.4.0`. The PR description mentions this inconsistency, but the diff only shows changes to the notebook cells and the `dependencies` dict in `ModelBuilder`. If `requirements.txt` is a separate file in the repo, it should also be updated to use `>=` constraints for consistency. Could you confirm whether `requirements.txt` is part of this PR or needs a separate change?