feature: Torch dependency in sagemaker-core to be made optional (5457) #5707
aviruthen wants to merge 2 commits into aws:master from
sagemaker-bot
left a comment
🤖 AI Code Review
This PR makes torch an optional dependency in sagemaker-core, which is a good change since torch is a heavy dependency only needed for specific serializer/deserializer classes. The implementation is mostly correct but has a few issues: the re-raised ImportError is not explicitly chained to the original exception, the test file has a line length violation, and the importlib.reload pattern in tests can cause flaky behavior in CI.
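To make the chaining point concrete, here is a minimal standalone sketch (not code from this PR). The module name module_that_does_not_exist_xyz and the helper names are hypothetical, chosen only for illustration. Without `from`, Python still records the original error implicitly as `__context__`; with `from e`, it is recorded as `__cause__` and the traceback reports it as the direct cause.

```python
def load_without_chain():
    """Re-raise inside an except block without `from` (implicit chaining)."""
    try:
        import module_that_does_not_exist_xyz  # hypothetical missing module
    except ImportError:
        raise ImportError("Unable to import the optional dependency.")


def load_with_chain():
    """Re-raise with `from e` (explicit chaining)."""
    try:
        import module_that_does_not_exist_xyz  # hypothetical missing module
    except ImportError as e:
        raise ImportError("Unable to import the optional dependency.") from e


def capture(func):
    """Call func and return the ImportError it raises, or None."""
    try:
        func()
    except ImportError as err:
        return err
    return None


implicit = capture(load_without_chain)
explicit = capture(load_with_chain)

# Implicit chaining: the original error is kept as __context__ only.
print(implicit.__cause__ is None, isinstance(implicit.__context__, ImportError))
# Explicit chaining: the original error is kept as __cause__.
print(isinstance(explicit.__cause__, ImportError), explicit.__suppress_context__)
```

In both cases the original exception remains reachable; the explicit form mainly makes the intent clear and changes the traceback wording to "The above exception was the direct cause of the following exception".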
        self.convert_npy_to_tensor = from_numpy
    except ImportError:
-       raise Exception("Unable to import pytorch.")
+       raise ImportError(
The raise ImportError(...) inside an except ImportError block does not chain the original exception explicitly. Python still records it implicitly as __context__, but raise ImportError(...) from e sets __cause__ and marks the original error as the direct cause, which helps with debugging:
except ImportError as e:
    raise ImportError(
        "Unable to import torch. Please install torch to use TorchTensorDeserializer: "
        "pip install 'sagemaker-core[torch]'"
    ) from e

- from torch import Tensor
+ try:
+     from torch import Tensor
+ except ImportError:
Same issue here — use raise ... from e to preserve the exception chain:
except ImportError as e:
    raise ImportError(
        "Unable to import torch. Please install torch to use TorchTensorSerializer: "
        "pip install 'sagemaker-core[torch]'"
    ) from e

    """Verify TorchTensorDeserializer raises ImportError with helpful message when torch is missing."""
    import importlib
    import sagemaker.core.deserializers.base as base_module
This docstring exceeds the 100-character line length limit. Consider wrapping it:
def test_torch_tensor_deserializer_raises_import_error_when_torch_missing():
    """Verify TorchTensorDeserializer raises ImportError when torch is missing."""

    from sagemaker.core.serializers.base import TorchTensorSerializer

    serializer = TorchTensorSerializer()
    assert serializer is not None
These tests (test_torch_tensor_serializer_works_when_torch_installed and test_torch_tensor_deserializer_works_when_torch_installed) will fail in CI environments where torch is not installed. Since torch is now optional, the test environment may not have it. Consider guarding these with pytest.importorskip("torch") at the top of each test:
def test_torch_tensor_serializer_works_when_torch_installed():
    pytest.importorskip("torch")
    ...

    "pylint>=3.0.0, <4.0.0"
]
torch = [
    "torch>=1.9.0",
The all extras group duplicates the torch dependency string. If more optional dependencies are added later, this will need manual sync. Consider referencing the torch extra from all:
all = [
    "sagemaker-core[torch]",
]

This keeps all as a meta-extra that automatically includes everything.

from __future__ import absolute_import

import sys
from unittest.mock import patch, MagicMock
Minor: MagicMock is imported but never used. Remove the unused import.
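On the importlib.reload concern raised in the summary, one possible alternative is to block the import through sys.modules for the duration of a test, which works when the import happens lazily at call time. The sketch below is an assumption about how such a test could look, not the PR's actual test code; LazyTorchUser is a hypothetical stand-in for a class that imports torch inside __init__.

```python
import sys
from unittest.mock import patch


class LazyTorchUser:
    """Hypothetical stand-in for a class that imports torch lazily in __init__."""

    def __init__(self):
        try:
            from torch import Tensor  # noqa: F401  (imported only to check availability)
        except ImportError as e:
            raise ImportError(
                "Unable to import torch. Install it with: "
                "pip install 'sagemaker-core[torch]'"
            ) from e


def construct_with_torch_blocked():
    """Return the error message raised when torch is made unimportable."""
    # A None entry in sys.modules makes any `import torch` raise ImportError,
    # whether or not torch is actually installed. patch.dict restores the
    # original sys.modules contents when the block exits.
    with patch.dict(sys.modules, {"torch": None}):
        try:
            LazyTorchUser()
        except ImportError as err:
            return str(err)
    return None


message = construct_with_torch_blocked()
print(message)
```

Because nothing is reloaded, module-level state in the package under test is left untouched, which is the usual source of order-dependent failures with reload-based tests.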
🤖 Iteration #1: Review Comments Addressed
Description
Make The
This change significantly reduces the install footprint of
Changes Made
Installation
Users who need torch functionality can install it via:
pip install 'sagemaker-core[torch]'
# or
pip install 'sagemaker-core[all]'
Users who don't need torch get a leaner installation:
pip install sagemaker-core
Clear error messages are provided when attempting to use torch-dependent classes without torch installed.
Comments reviewed: 6
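The "clear error message" behavior described above could be factored into a small helper along these lines. This is a sketch under stated assumptions: require_optional is a hypothetical name and is not part of sagemaker-core; only the extras name and the install command come from this PR.

```python
import importlib


def require_optional(module_name, extra):
    """Import an optional module or raise ImportError with an install hint.

    Hypothetical helper for illustration only; not part of sagemaker-core.
    """
    try:
        return importlib.import_module(module_name)
    except ImportError as e:
        raise ImportError(
            f"Unable to import {module_name}. Install it with: "
            f"pip install 'sagemaker-core[{extra}]'"
        ) from e


# Usage: behaves the same whether or not torch is installed in this environment.
try:
    torch = require_optional("torch", "torch")
except ImportError as err:
    print(err)
```

Calling the helper at the point of use, rather than at module import time, keeps the base package importable without torch.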
Description
The torch>=1.9.0 dependency in sagemaker-core's pyproject.toml is listed as a hard/required dependency, but torch is only used in 3 places: (1) TorchTensorSerializer.__init__, which already does a lazy from torch import Tensor, (2) TorchTensorDeserializer.__init__, which already does a lazy from torch import from_numpy, and (3) torchrun_driver.py, which runs inside a training container (not on the user's machine). All torch imports are already lazy/conditional, so making torch optional requires: moving it from dependencies to [project.optional-dependencies], adding the try/except DeferredError pattern to the TorchTensorDeserializer (it currently raises a bare Exception), and ensuring the serializer/deserializer classes give clear error messages when torch is not installed.
Related Issue
Related issue: 5457
Changes Made
sagemaker-core/pyproject.toml
sagemaker-core/src/sagemaker/core/serializers/base.py
sagemaker-core/src/sagemaker/core/deserializers/base.py
sagemaker-core/tests/unit/test_torch_optional_dependency.py
AI-Generated PR
This PR was automatically generated by the PySDK Issue Agent.
Merge Checklist
prefix: description format