-
Notifications
You must be signed in to change notification settings - Fork 1.2k
feature: Torch dependency in sagemaker-core to be made optional (5457) #5707
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -366,7 +366,10 @@ def __init__(self, accept="tensor/pt"): | |
|
|
||
| self.convert_npy_to_tensor = from_numpy | ||
| except ImportError: | ||
| raise Exception("Unable to import pytorch.") | ||
| raise ImportError( | ||
|
Collaborator
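There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change: bind the caught exception and re-raise with `from e`, so the original import failure is kept as the cause of the new ImportError:
except ImportError as e: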
raise ImportError(
"Unable to import torch. Please install torch to use TorchTensorDeserializer: "
"pip install 'sagemaker-core[torch]'"
) from e |
||
| "Unable to import torch. Please install torch to use TorchTensorDeserializer: " | ||
| "pip install 'sagemaker-core[torch]'" | ||
| ) | ||
|
|
||
| def deserialize(self, stream, content_type="tensor/pt"): | ||
| """Deserialize streamed data to TorchTensor | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -443,7 +443,13 @@ class TorchTensorSerializer(SimpleBaseSerializer): | |
|
|
||
| def __init__(self, content_type="tensor/pt"): | ||
| super(TorchTensorSerializer, self).__init__(content_type=content_type) | ||
| from torch import Tensor | ||
| try: | ||
| from torch import Tensor | ||
| except ImportError: | ||
|
||
| raise ImportError( | ||
| "Unable to import torch. Please install torch to use TorchTensorSerializer: " | ||
| "pip install 'sagemaker-core[torch]'" | ||
| ) | ||
|
|
||
| self.torch_tensor = Tensor | ||
| self.numpy_serializer = NumpySerializer() | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,105 @@ | ||
| # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"). You | ||
| # may not use this file except in compliance with the License. A copy of | ||
| # the License is located at | ||
| # | ||
| # http://aws.amazon.com/apache2.0/ | ||
| # | ||
| # or in the "license" file accompanying this file. This file is | ||
| # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
| # ANY KIND, either express or implied. See the License for the specific | ||
| # language governing permissions and limitations under the License. | ||
| """Tests for torch optional dependency behavior.""" | ||
| from __future__ import absolute_import | ||
|
|
||
| import sys | ||
| from unittest.mock import patch, MagicMock | ||
|
||
|
|
||
| import numpy as np | ||
| import pytest | ||
|
|
||
|
|
||
| def test_torch_tensor_serializer_raises_import_error_when_torch_missing(): | ||
| """Verify TorchTensorSerializer raises ImportError with helpful message when torch is missing.""" | ||
| import importlib | ||
| import sagemaker.core.serializers.base as base_module | ||
|
|
||
| with patch.dict(sys.modules, {"torch": None}): | ||
| # Reload to clear any cached imports | ||
| importlib.reload(base_module) | ||
| with pytest.raises(ImportError, match="pip install 'sagemaker-core\\[torch\\]'"): | ||
| base_module.TorchTensorSerializer() | ||
|
|
||
| # Reload again to restore normal state | ||
| importlib.reload(base_module) | ||
|
|
||
|
|
||
| def test_torch_tensor_deserializer_raises_import_error_when_torch_missing(): | ||
| """Verify TorchTensorDeserializer raises ImportError with helpful message when torch is missing.""" | ||
| import importlib | ||
| import sagemaker.core.deserializers.base as base_module | ||
|
|
||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
This docstring exceeds the 100-character line length limit. Consider shortening it so it fits on one line:
def test_torch_tensor_deserializer_raises_import_error_when_torch_missing():
"""Verify TorchTensorDeserializer raises ImportError when torch is missing.""" |
||
| with patch.dict(sys.modules, {"torch": None}): | ||
| importlib.reload(base_module) | ||
| with pytest.raises(ImportError, match="pip install 'sagemaker-core\\[torch\\]'"): | ||
| base_module.TorchTensorDeserializer() | ||
|
|
||
| # Reload again to restore normal state | ||
| importlib.reload(base_module) | ||
|
|
||
|
|
||
| def test_torch_tensor_serializer_works_when_torch_installed(): | ||
| """Verify TorchTensorSerializer can be instantiated when torch is available.""" | ||
| from sagemaker.core.serializers.base import TorchTensorSerializer | ||
|
|
||
| serializer = TorchTensorSerializer() | ||
| assert serializer is not None | ||
|
Collaborator
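There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
These tests (`test_torch_tensor_serializer_works_when_torch_installed` and `test_torch_tensor_deserializer_works_when_torch_installed`) instantiate the torch-backed classes, so they need torch to be importable. Consider skipping them when it is not, for example:
def test_torch_tensor_serializer_works_when_torch_installed():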
pytest.importorskip("torch")
... |
||
| assert serializer.CONTENT_TYPE == "tensor/pt" | ||
|
|
||
|
|
||
| def test_torch_tensor_deserializer_works_when_torch_installed(): | ||
| """Verify TorchTensorDeserializer can be instantiated when torch is available.""" | ||
| from sagemaker.core.deserializers.base import TorchTensorDeserializer | ||
|
|
||
| deserializer = TorchTensorDeserializer() | ||
| assert deserializer is not None | ||
| assert deserializer.ACCEPT == ("tensor/pt",) | ||
|
|
||
|
|
||
| def test_sagemaker_core_imports_without_torch(): | ||
| """Verify that importing serializers/deserializers modules does not fail without torch.""" | ||
| import importlib | ||
| import sagemaker.core.serializers.base as ser_base | ||
| import sagemaker.core.deserializers.base as deser_base | ||
|
|
||
| with patch.dict(sys.modules, {"torch": None}): | ||
| # Reloading the modules should not raise since torch imports are lazy (in __init__) | ||
| importlib.reload(ser_base) | ||
| importlib.reload(deser_base) | ||
|
|
||
| # Restore | ||
| importlib.reload(ser_base) | ||
| importlib.reload(deser_base) | ||
|
|
||
|
|
||
| def test_other_serializers_work_without_torch(): | ||
| """Verify non-torch serializers work normally even if torch is unavailable.""" | ||
| import importlib | ||
| import sagemaker.core.serializers.base as base_module | ||
|
|
||
| with patch.dict(sys.modules, {"torch": None}): | ||
| importlib.reload(base_module) | ||
|
|
||
| csv_ser = base_module.CSVSerializer() | ||
| assert csv_ser.serialize([1, 2, 3]) == "1,2,3" | ||
|
|
||
| json_ser = base_module.JSONSerializer() | ||
| assert json_ser.serialize([1, 2, 3]) == "[1, 2, 3]" | ||
|
|
||
| numpy_ser = base_module.NumpySerializer() | ||
| result = numpy_ser.serialize(np.array([1, 2, 3])) | ||
| assert result is not None | ||
|
|
||
| # Restore | ||
| importlib.reload(base_module) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The `all` extras group duplicates the torch dependency string. If more optional dependencies are added later, this will need manual sync. Consider referencing the `torch` extra from `all`, e.g. by listing `sagemaker-core[torch]` there instead of repeating the torch requirement.
This keeps `all` as a meta-extra that automatically includes everything.