Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion sagemaker-core/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ dependencies = [
"smdebug_rulesconfig>=1.0.1",
"schema>=0.7.5",
"omegaconf>=2.1.0",
"torch>=1.9.0",
"scipy>=1.5.0",
# Remote function dependencies
"cloudpickle>=2.0.0",
Expand All @@ -57,10 +56,17 @@ codegen = [
"pytest>=8.0.0, <9.0.0",
"pylint>=3.0.0, <4.0.0"
]
torch = [
"torch>=1.9.0",
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The all extras group duplicates the torch dependency string. If more optional dependencies are added later, this will need manual sync. Consider referencing the torch extra from all:

all = [
    "sagemaker-core[torch]",
]

This keeps all as a meta-extra that automatically includes everything.

]
all = [
"torch>=1.9.0",
]
test = [
"pytest>=8.0.0, <9.0.0",
"pytest-cov>=4.0.0",
"pytest-xdist>=3.0.0",
"torch>=1.9.0",
]

[project.urls]
Expand Down
5 changes: 4 additions & 1 deletion sagemaker-core/src/sagemaker/core/deserializers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,10 @@ def __init__(self, accept="tensor/pt"):

self.convert_npy_to_tensor = from_numpy
except ImportError:
raise Exception("Unable to import pytorch.")
raise ImportError(
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The raise ImportError(...) inside an except ImportError block loses the original exception context. Use raise ImportError(...) from e to preserve the exception chain, which helps with debugging:

except ImportError as e:
    raise ImportError(
        "Unable to import torch. Please install torch to use TorchTensorDeserializer: "
        "pip install 'sagemaker-core[torch]'"
    ) from e

"Unable to import torch. Please install torch to use TorchTensorDeserializer: "
"pip install 'sagemaker-core[torch]'"
)

def deserialize(self, stream, content_type="tensor/pt"):
"""Deserialize streamed data to TorchTensor
Expand Down
8 changes: 7 additions & 1 deletion sagemaker-core/src/sagemaker/core/serializers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,13 @@ class TorchTensorSerializer(SimpleBaseSerializer):

def __init__(self, content_type="tensor/pt"):
super(TorchTensorSerializer, self).__init__(content_type=content_type)
from torch import Tensor
try:
from torch import Tensor
except ImportError:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same issue here — use raise ... from e to preserve the exception chain:

except ImportError as e:
    raise ImportError(
        "Unable to import torch. Please install torch to use TorchTensorSerializer: "
        "pip install 'sagemaker-core[torch]'"
    ) from e

raise ImportError(
"Unable to import torch. Please install torch to use TorchTensorSerializer: "
"pip install 'sagemaker-core[torch]'"
)

self.torch_tensor = Tensor
self.numpy_serializer = NumpySerializer()
Expand Down
105 changes: 105 additions & 0 deletions sagemaker-core/tests/unit/test_torch_optional_dependency.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Tests for torch optional dependency behavior."""
from __future__ import absolute_import

import sys
from unittest.mock import patch, MagicMock
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor: MagicMock is imported but never used. Remove the unused import.


import numpy as np
import pytest


def test_torch_tensor_serializer_raises_import_error_when_torch_missing():
"""Verify TorchTensorSerializer raises ImportError with helpful message when torch is missing."""
import importlib
import sagemaker.core.serializers.base as base_module

with patch.dict(sys.modules, {"torch": None}):
# Reload to clear any cached imports
importlib.reload(base_module)
with pytest.raises(ImportError, match="pip install 'sagemaker-core\\[torch\\]'"):
base_module.TorchTensorSerializer()

# Reload again to restore normal state
importlib.reload(base_module)


def test_torch_tensor_deserializer_raises_import_error_when_torch_missing():
"""Verify TorchTensorDeserializer raises ImportError with helpful message when torch is missing."""
import importlib
import sagemaker.core.deserializers.base as base_module

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This docstring exceeds the 100-character line length limit. Consider wrapping it:

def test_torch_tensor_deserializer_raises_import_error_when_torch_missing():
    """Verify TorchTensorDeserializer raises ImportError when torch is missing."""

with patch.dict(sys.modules, {"torch": None}):
importlib.reload(base_module)
with pytest.raises(ImportError, match="pip install 'sagemaker-core\\[torch\\]'"):
base_module.TorchTensorDeserializer()

# Reload again to restore normal state
importlib.reload(base_module)


def test_torch_tensor_serializer_works_when_torch_installed():
"""Verify TorchTensorSerializer can be instantiated when torch is available."""
from sagemaker.core.serializers.base import TorchTensorSerializer

serializer = TorchTensorSerializer()
assert serializer is not None
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These tests (test_torch_tensor_serializer_works_when_torch_installed and test_torch_tensor_deserializer_works_when_torch_installed) will fail in CI environments where torch is not installed. Since torch is now optional, the test environment may not have it. Consider guarding these with pytest.importorskip("torch") at the top of each test:

def test_torch_tensor_serializer_works_when_torch_installed():
    pytest.importorskip("torch")
    ...

assert serializer.CONTENT_TYPE == "tensor/pt"


def test_torch_tensor_deserializer_works_when_torch_installed():
"""Verify TorchTensorDeserializer can be instantiated when torch is available."""
from sagemaker.core.deserializers.base import TorchTensorDeserializer

deserializer = TorchTensorDeserializer()
assert deserializer is not None
assert deserializer.ACCEPT == ("tensor/pt",)


def test_sagemaker_core_imports_without_torch():
"""Verify that importing serializers/deserializers modules does not fail without torch."""
import importlib
import sagemaker.core.serializers.base as ser_base
import sagemaker.core.deserializers.base as deser_base

with patch.dict(sys.modules, {"torch": None}):
# Reloading the modules should not raise since torch imports are lazy (in __init__)
importlib.reload(ser_base)
importlib.reload(deser_base)

# Restore
importlib.reload(ser_base)
importlib.reload(deser_base)


def test_other_serializers_work_without_torch():
"""Verify non-torch serializers work normally even if torch is unavailable."""
import importlib
import sagemaker.core.serializers.base as base_module

with patch.dict(sys.modules, {"torch": None}):
importlib.reload(base_module)

csv_ser = base_module.CSVSerializer()
assert csv_ser.serialize([1, 2, 3]) == "1,2,3"

json_ser = base_module.JSONSerializer()
assert json_ser.serialize([1, 2, 3]) == "[1, 2, 3]"

numpy_ser = base_module.NumpySerializer()
result = numpy_ser.serialize(np.array([1, 2, 3]))
assert result is not None

# Restore
importlib.reload(base_module)
Loading