From 769d06e69fda8a44075c3a1e74069f731bad7a0f Mon Sep 17 00:00:00 2001 From: aviruthen <91846056+aviruthen@users.noreply.github.com> Date: Tue, 31 Mar 2026 19:12:26 -0400 Subject: [PATCH 1/2] feature: Torch dependency in sagemaker-core to be made optional (5457) --- sagemaker-core/pyproject.toml | 8 +- .../src/sagemaker/core/deserializers/base.py | 5 +- .../src/sagemaker/core/serializers/base.py | 8 +- .../unit/test_torch_optional_dependency.py | 105 ++++++++++++++++++ 4 files changed, 123 insertions(+), 3 deletions(-) create mode 100644 sagemaker-core/tests/unit/test_torch_optional_dependency.py diff --git a/sagemaker-core/pyproject.toml b/sagemaker-core/pyproject.toml index 2756ce0f1c..53b6857d47 100644 --- a/sagemaker-core/pyproject.toml +++ b/sagemaker-core/pyproject.toml @@ -32,7 +32,6 @@ dependencies = [ "smdebug_rulesconfig>=1.0.1", "schema>=0.7.5", "omegaconf>=2.1.0", - "torch>=1.9.0", "scipy>=1.5.0", # Remote function dependencies "cloudpickle>=2.0.0", @@ -57,10 +56,17 @@ codegen = [ "pytest>=8.0.0, <9.0.0", "pylint>=3.0.0, <4.0.0" ] +torch = [ + "torch>=1.9.0", +] +all = [ + "torch>=1.9.0", +] test = [ "pytest>=8.0.0, <9.0.0", "pytest-cov>=4.0.0", "pytest-xdist>=3.0.0", + "torch>=1.9.0", ] [project.urls] diff --git a/sagemaker-core/src/sagemaker/core/deserializers/base.py b/sagemaker-core/src/sagemaker/core/deserializers/base.py index 4faae7db74..1f7ec9ab06 100644 --- a/sagemaker-core/src/sagemaker/core/deserializers/base.py +++ b/sagemaker-core/src/sagemaker/core/deserializers/base.py @@ -366,7 +366,10 @@ def __init__(self, accept="tensor/pt"): self.convert_npy_to_tensor = from_numpy except ImportError: - raise Exception("Unable to import pytorch.") + raise ImportError( + "Unable to import torch. 
Please install torch to use TorchTensorDeserializer: " + "pip install 'sagemaker-core[torch]'" + ) def deserialize(self, stream, content_type="tensor/pt"): """Deserialize streamed data to TorchTensor diff --git a/sagemaker-core/src/sagemaker/core/serializers/base.py b/sagemaker-core/src/sagemaker/core/serializers/base.py index a4ecf7c1dc..4b3ba4fdba 100644 --- a/sagemaker-core/src/sagemaker/core/serializers/base.py +++ b/sagemaker-core/src/sagemaker/core/serializers/base.py @@ -443,7 +443,13 @@ class TorchTensorSerializer(SimpleBaseSerializer): def __init__(self, content_type="tensor/pt"): super(TorchTensorSerializer, self).__init__(content_type=content_type) - from torch import Tensor + try: + from torch import Tensor + except ImportError: + raise ImportError( + "Unable to import torch. Please install torch to use TorchTensorSerializer: " + "pip install 'sagemaker-core[torch]'" + ) self.torch_tensor = Tensor self.numpy_serializer = NumpySerializer() diff --git a/sagemaker-core/tests/unit/test_torch_optional_dependency.py b/sagemaker-core/tests/unit/test_torch_optional_dependency.py new file mode 100644 index 0000000000..51ae6f3571 --- /dev/null +++ b/sagemaker-core/tests/unit/test_torch_optional_dependency.py @@ -0,0 +1,105 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+"""Tests for torch optional dependency behavior.""" +from __future__ import absolute_import + +import sys +from unittest.mock import patch, MagicMock + +import numpy as np +import pytest + + +def test_torch_tensor_serializer_raises_import_error_when_torch_missing(): + """Verify TorchTensorSerializer raises ImportError with helpful message when torch is missing.""" + import importlib + import sagemaker.core.serializers.base as base_module + + with patch.dict(sys.modules, {"torch": None}): + # Reload to clear any cached imports + importlib.reload(base_module) + with pytest.raises(ImportError, match="pip install 'sagemaker-core\\[torch\\]'"): + base_module.TorchTensorSerializer() + + # Reload again to restore normal state + importlib.reload(base_module) + + +def test_torch_tensor_deserializer_raises_import_error_when_torch_missing(): + """Verify TorchTensorDeserializer raises ImportError with helpful message when torch is missing.""" + import importlib + import sagemaker.core.deserializers.base as base_module + + with patch.dict(sys.modules, {"torch": None}): + importlib.reload(base_module) + with pytest.raises(ImportError, match="pip install 'sagemaker-core\\[torch\\]'"): + base_module.TorchTensorDeserializer() + + # Reload again to restore normal state + importlib.reload(base_module) + + +def test_torch_tensor_serializer_works_when_torch_installed(): + """Verify TorchTensorSerializer can be instantiated when torch is available.""" + from sagemaker.core.serializers.base import TorchTensorSerializer + + serializer = TorchTensorSerializer() + assert serializer is not None + assert serializer.CONTENT_TYPE == "tensor/pt" + + +def test_torch_tensor_deserializer_works_when_torch_installed(): + """Verify TorchTensorDeserializer can be instantiated when torch is available.""" + from sagemaker.core.deserializers.base import TorchTensorDeserializer + + deserializer = TorchTensorDeserializer() + assert deserializer is not None + assert deserializer.ACCEPT == ("tensor/pt",) + + 
+def test_sagemaker_core_imports_without_torch(): + """Verify that importing serializers/deserializers modules does not fail without torch.""" + import importlib + import sagemaker.core.serializers.base as ser_base + import sagemaker.core.deserializers.base as deser_base + + with patch.dict(sys.modules, {"torch": None}): + # Reloading the modules should not raise since torch imports are lazy (in __init__) + importlib.reload(ser_base) + importlib.reload(deser_base) + + # Restore + importlib.reload(ser_base) + importlib.reload(deser_base) + + +def test_other_serializers_work_without_torch(): + """Verify non-torch serializers work normally even if torch is unavailable.""" + import importlib + import sagemaker.core.serializers.base as base_module + + with patch.dict(sys.modules, {"torch": None}): + importlib.reload(base_module) + + csv_ser = base_module.CSVSerializer() + assert csv_ser.serialize([1, 2, 3]) == "1,2,3" + + json_ser = base_module.JSONSerializer() + assert json_ser.serialize([1, 2, 3]) == "[1, 2, 3]" + + numpy_ser = base_module.NumpySerializer() + result = numpy_ser.serialize(np.array([1, 2, 3])) + assert result is not None + + # Restore + importlib.reload(base_module) From 2cc8af95da5d90c803fa995c5750aa713ddc8dc6 Mon Sep 17 00:00:00 2001 From: aviruthen <91846056+aviruthen@users.noreply.github.com> Date: Tue, 31 Mar 2026 19:20:17 -0400 Subject: [PATCH 2/2] fix: address review comments (iteration #1) --- sagemaker-core/pyproject.toml | 2 +- sagemaker-core/src/sagemaker/core/deserializers/base.py | 4 ++-- sagemaker-core/src/sagemaker/core/serializers/base.py | 4 ++-- sagemaker-core/tests/unit/test_torch_optional_dependency.py | 6 ++++-- 4 files changed, 9 insertions(+), 7 deletions(-) diff --git a/sagemaker-core/pyproject.toml b/sagemaker-core/pyproject.toml index 53b6857d47..bd702788a3 100644 --- a/sagemaker-core/pyproject.toml +++ b/sagemaker-core/pyproject.toml @@ -60,7 +60,7 @@ torch = [ "torch>=1.9.0", ] all = [ - "torch>=1.9.0", + 
"sagemaker-core[torch]", ] test = [ "pytest>=8.0.0, <9.0.0", diff --git a/sagemaker-core/src/sagemaker/core/deserializers/base.py b/sagemaker-core/src/sagemaker/core/deserializers/base.py index 1f7ec9ab06..03138ed577 100644 --- a/sagemaker-core/src/sagemaker/core/deserializers/base.py +++ b/sagemaker-core/src/sagemaker/core/deserializers/base.py @@ -365,11 +365,11 @@ def __init__(self, accept="tensor/pt"): from torch import from_numpy self.convert_npy_to_tensor = from_numpy - except ImportError: + except ImportError as e: raise ImportError( "Unable to import torch. Please install torch to use TorchTensorDeserializer: " "pip install 'sagemaker-core[torch]'" - ) + ) from e def deserialize(self, stream, content_type="tensor/pt"): """Deserialize streamed data to TorchTensor diff --git a/sagemaker-core/src/sagemaker/core/serializers/base.py b/sagemaker-core/src/sagemaker/core/serializers/base.py index 4b3ba4fdba..e8862b66f3 100644 --- a/sagemaker-core/src/sagemaker/core/serializers/base.py +++ b/sagemaker-core/src/sagemaker/core/serializers/base.py @@ -445,11 +445,11 @@ def __init__(self, content_type="tensor/pt"): super(TorchTensorSerializer, self).__init__(content_type=content_type) try: from torch import Tensor - except ImportError: + except ImportError as e: raise ImportError( "Unable to import torch. 
Please install torch to use TorchTensorSerializer: " "pip install 'sagemaker-core[torch]'" - ) + ) from e self.torch_tensor = Tensor self.numpy_serializer = NumpySerializer() diff --git a/sagemaker-core/tests/unit/test_torch_optional_dependency.py b/sagemaker-core/tests/unit/test_torch_optional_dependency.py index 51ae6f3571..42d0466abd 100644 --- a/sagemaker-core/tests/unit/test_torch_optional_dependency.py +++ b/sagemaker-core/tests/unit/test_torch_optional_dependency.py @@ -14,7 +14,7 @@ from __future__ import absolute_import import sys -from unittest.mock import patch, MagicMock +from unittest.mock import patch import numpy as np import pytest @@ -36,7 +36,7 @@ def test_torch_tensor_serializer_raises_import_error_when_torch_missing(): def test_torch_tensor_deserializer_raises_import_error_when_torch_missing(): - """Verify TorchTensorDeserializer raises ImportError with helpful message when torch is missing.""" + """Verify TorchTensorDeserializer raises ImportError when torch is missing.""" import importlib import sagemaker.core.deserializers.base as base_module @@ -51,6 +51,7 @@ def test_torch_tensor_deserializer_raises_import_error_when_torch_missing(): def test_torch_tensor_serializer_works_when_torch_installed(): """Verify TorchTensorSerializer can be instantiated when torch is available.""" + pytest.importorskip("torch") from sagemaker.core.serializers.base import TorchTensorSerializer serializer = TorchTensorSerializer() @@ -60,6 +61,7 @@ def test_torch_tensor_serializer_works_when_torch_installed(): def test_torch_tensor_deserializer_works_when_torch_installed(): """Verify TorchTensorDeserializer can be instantiated when torch is available.""" + pytest.importorskip("torch") from sagemaker.core.deserializers.base import TorchTensorDeserializer deserializer = TorchTensorDeserializer()