Skip to content

Commit 769d06e

Browse files
committed
feature: Torch dependency in sagameker-core to be made optional (5457)
1 parent 0976df1 commit 769d06e

File tree

4 files changed

+123
-3
lines changed

4 files changed

+123
-3
lines changed

sagemaker-core/pyproject.toml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@ dependencies = [
3232
"smdebug_rulesconfig>=1.0.1",
3333
"schema>=0.7.5",
3434
"omegaconf>=2.1.0",
35-
"torch>=1.9.0",
3635
"scipy>=1.5.0",
3736
# Remote function dependencies
3837
"cloudpickle>=2.0.0",
@@ -57,10 +56,17 @@ codegen = [
5756
"pytest>=8.0.0, <9.0.0",
5857
"pylint>=3.0.0, <4.0.0"
5958
]
59+
torch = [
60+
"torch>=1.9.0",
61+
]
62+
all = [
63+
"torch>=1.9.0",
64+
]
6065
test = [
6166
"pytest>=8.0.0, <9.0.0",
6267
"pytest-cov>=4.0.0",
6368
"pytest-xdist>=3.0.0",
69+
"torch>=1.9.0",
6470
]
6571

6672
[project.urls]

sagemaker-core/src/sagemaker/core/deserializers/base.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -366,7 +366,10 @@ def __init__(self, accept="tensor/pt"):
366366

367367
self.convert_npy_to_tensor = from_numpy
368368
except ImportError:
369-
raise Exception("Unable to import pytorch.")
369+
raise ImportError(
370+
"Unable to import torch. Please install torch to use TorchTensorDeserializer: "
371+
"pip install 'sagemaker-core[torch]'"
372+
)
370373

371374
def deserialize(self, stream, content_type="tensor/pt"):
372375
"""Deserialize streamed data to TorchTensor

sagemaker-core/src/sagemaker/core/serializers/base.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -443,7 +443,13 @@ class TorchTensorSerializer(SimpleBaseSerializer):
443443

444444
def __init__(self, content_type="tensor/pt"):
445445
super(TorchTensorSerializer, self).__init__(content_type=content_type)
446-
from torch import Tensor
446+
try:
447+
from torch import Tensor
448+
except ImportError:
449+
raise ImportError(
450+
"Unable to import torch. Please install torch to use TorchTensorSerializer: "
451+
"pip install 'sagemaker-core[torch]'"
452+
)
447453

448454
self.torch_tensor = Tensor
449455
self.numpy_serializer = NumpySerializer()
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Tests for torch optional dependency behavior."""
14+
from __future__ import absolute_import
15+
16+
import sys
17+
from unittest.mock import patch, MagicMock
18+
19+
import numpy as np
20+
import pytest
21+
22+
23+
def test_torch_tensor_serializer_raises_import_error_when_torch_missing():
24+
"""Verify TorchTensorSerializer raises ImportError with helpful message when torch is missing."""
25+
import importlib
26+
import sagemaker.core.serializers.base as base_module
27+
28+
with patch.dict(sys.modules, {"torch": None}):
29+
# Reload to clear any cached imports
30+
importlib.reload(base_module)
31+
with pytest.raises(ImportError, match="pip install 'sagemaker-core\\[torch\\]'"):
32+
base_module.TorchTensorSerializer()
33+
34+
# Reload again to restore normal state
35+
importlib.reload(base_module)
36+
37+
38+
def test_torch_tensor_deserializer_raises_import_error_when_torch_missing():
39+
"""Verify TorchTensorDeserializer raises ImportError with helpful message when torch is missing."""
40+
import importlib
41+
import sagemaker.core.deserializers.base as base_module
42+
43+
with patch.dict(sys.modules, {"torch": None}):
44+
importlib.reload(base_module)
45+
with pytest.raises(ImportError, match="pip install 'sagemaker-core\\[torch\\]'"):
46+
base_module.TorchTensorDeserializer()
47+
48+
# Reload again to restore normal state
49+
importlib.reload(base_module)
50+
51+
52+
def test_torch_tensor_serializer_works_when_torch_installed():
53+
"""Verify TorchTensorSerializer can be instantiated when torch is available."""
54+
from sagemaker.core.serializers.base import TorchTensorSerializer
55+
56+
serializer = TorchTensorSerializer()
57+
assert serializer is not None
58+
assert serializer.CONTENT_TYPE == "tensor/pt"
59+
60+
61+
def test_torch_tensor_deserializer_works_when_torch_installed():
62+
"""Verify TorchTensorDeserializer can be instantiated when torch is available."""
63+
from sagemaker.core.deserializers.base import TorchTensorDeserializer
64+
65+
deserializer = TorchTensorDeserializer()
66+
assert deserializer is not None
67+
assert deserializer.ACCEPT == ("tensor/pt",)
68+
69+
70+
def test_sagemaker_core_imports_without_torch():
71+
"""Verify that importing serializers/deserializers modules does not fail without torch."""
72+
import importlib
73+
import sagemaker.core.serializers.base as ser_base
74+
import sagemaker.core.deserializers.base as deser_base
75+
76+
with patch.dict(sys.modules, {"torch": None}):
77+
# Reloading the modules should not raise since torch imports are lazy (in __init__)
78+
importlib.reload(ser_base)
79+
importlib.reload(deser_base)
80+
81+
# Restore
82+
importlib.reload(ser_base)
83+
importlib.reload(deser_base)
84+
85+
86+
def test_other_serializers_work_without_torch():
87+
"""Verify non-torch serializers work normally even if torch is unavailable."""
88+
import importlib
89+
import sagemaker.core.serializers.base as base_module
90+
91+
with patch.dict(sys.modules, {"torch": None}):
92+
importlib.reload(base_module)
93+
94+
csv_ser = base_module.CSVSerializer()
95+
assert csv_ser.serialize([1, 2, 3]) == "1,2,3"
96+
97+
json_ser = base_module.JSONSerializer()
98+
assert json_ser.serialize([1, 2, 3]) == "[1, 2, 3]"
99+
100+
numpy_ser = base_module.NumpySerializer()
101+
result = numpy_ser.serialize(np.array([1, 2, 3]))
102+
assert result is not None
103+
104+
# Restore
105+
importlib.reload(base_module)

0 commit comments

Comments
 (0)