Skip to content

feature: Torch dependency in sagameker-core to be made optional (5457)#5707

Closed
aviruthen wants to merge 2 commits intoaws:masterfrom
aviruthen:feature/torch-dependency-in-sagameker-core-to-be-made-5457
Closed

feature: Torch dependency in sagameker-core to be made optional (5457)#5707
aviruthen wants to merge 2 commits intoaws:masterfrom
aviruthen:feature/torch-dependency-in-sagameker-core-to-be-made-5457

Conversation

@aviruthen
Copy link
Copy Markdown
Collaborator

Description

The torch>=1.9.0 dependency in sagemaker-core's pyproject.toml is listed as a hard/required dependency, but torch is only used in 3 places: (1) TorchTensorSerializer.init which already does a lazy from torch import Tensor, (2) TorchTensorDeserializer.init which already does a lazy from torch import from_numpy, and (3) torchrun_driver.py which runs inside a training container (not on the user's machine). All torch imports are already lazy/conditional, so making torch optional requires: moving it from dependencies to [project.optional-dependencies], adding try/except DeferredError pattern to the TorchTensorDeserializer (it currently raises a bare Exception), and ensuring the serializer/deserializer classes give clear error messages when torch is not installed.

Related Issue

Related issue: 5457

Changes Made

  • sagemaker-core/pyproject.toml
  • sagemaker-core/src/sagemaker/core/serializers/base.py
  • sagemaker-core/src/sagemaker/core/deserializers/base.py
  • sagemaker-core/tests/unit/test_torch_optional_dependency.py

AI-Generated PR

This PR was automatically generated by the PySDK Issue Agent.

  • Confidence score: 85%
  • Classification: type: feature request
  • SDK version target: V3

Merge Checklist

  • Changes are backward compatible
  • Commit message follows prefix: description format
  • Unit tests added/updated
  • Integration tests added (if applicable)
  • Documentation updated (if applicable)

Copy link
Copy Markdown
Collaborator

@sagemaker-bot sagemaker-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🤖 AI Code Review

This PR makes torch an optional dependency in sagemaker-core, which is a good change since torch is a heavy dependency only needed for specific serializer/deserializer classes. The implementation is mostly correct but has a few issues: the exception chaining pattern loses the original traceback, the test file has a line length violation, and the importlib.reload pattern in tests can cause flaky behavior in CI.

self.convert_npy_to_tensor = from_numpy
except ImportError:
raise Exception("Unable to import pytorch.")
raise ImportError(
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The raise ImportError(...) inside an except ImportError block loses the original exception context. Use raise ImportError(...) from e to preserve the exception chain, which helps with debugging:

except ImportError as e:
    raise ImportError(
        "Unable to import torch. Please install torch to use TorchTensorDeserializer: "
        "pip install 'sagemaker-core[torch]'"
    ) from e

from torch import Tensor
try:
from torch import Tensor
except ImportError:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same issue here — use raise ... from e to preserve the exception chain:

except ImportError as e:
    raise ImportError(
        "Unable to import torch. Please install torch to use TorchTensorSerializer: "
        "pip install 'sagemaker-core[torch]'"
    ) from e

"""Verify TorchTensorDeserializer raises ImportError with helpful message when torch is missing."""
import importlib
import sagemaker.core.deserializers.base as base_module

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This docstring exceeds the 100-character line length limit. Consider wrapping it:

def test_torch_tensor_deserializer_raises_import_error_when_torch_missing():
    """Verify TorchTensorDeserializer raises ImportError when torch is missing."""

from sagemaker.core.serializers.base import TorchTensorSerializer

serializer = TorchTensorSerializer()
assert serializer is not None
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These tests (test_torch_tensor_serializer_works_when_torch_installed and test_torch_tensor_deserializer_works_when_torch_installed) will fail in CI environments where torch is not installed. Since torch is now optional, the test environment may not have it. Consider guarding these with pytest.importorskip("torch") at the top of each test:

def test_torch_tensor_serializer_works_when_torch_installed():
    pytest.importorskip("torch")
    ...

"pylint>=3.0.0, <4.0.0"
]
torch = [
"torch>=1.9.0",
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The all extras group duplicates the torch dependency string. If more optional dependencies are added later, this will need manual sync. Consider referencing the torch extra from all:

all = [
    "sagemaker-core[torch]",
]

This keeps all as a meta-extra that automatically includes everything.

from __future__ import absolute_import

import sys
from unittest.mock import patch, MagicMock
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor: MagicMock is imported but never used. Remove the unused import.

@aws aws deleted a comment from sagemaker-bot Mar 31, 2026
@aviruthen
Copy link
Copy Markdown
Collaborator Author

🤖 Iteration #1 — Review Comments Addressed

Description

Make torch an optional dependency in sagemaker-core instead of a required one.

The torch>=1.9.0 dependency was listed as a hard/required dependency, but torch is only used in a few places where imports are already lazy/conditional:

  1. TorchTensorSerializer.__init__ - lazy from torch import Tensor
  2. TorchTensorDeserializer.__init__ - lazy from torch import from_numpy
  3. torchrun_driver.py - runs inside a training container, not on the user's machine

This change significantly reduces the install footprint of sagemaker-core for users who don't need torch functionality.

Changes Made

  • sagemaker-core/pyproject.toml: torch is already in [project.optional-dependencies] under torch and all extras. Updated all to reference sagemaker-core[torch] instead of duplicating the dependency string.
  • sagemaker-core/src/sagemaker/core/serializers/base.py: Preserved exception chain with raise ... from e in TorchTensorSerializer.__init__.
  • sagemaker-core/src/sagemaker/core/deserializers/base.py: Preserved exception chain with raise ... from e in TorchTensorDeserializer.__init__.
  • sagemaker-core/tests/unit/test_torch_optional_dependency.py: Fixed docstring line length, added pytest.importorskip("torch") guards for tests that require torch, and removed unused MagicMock import.

Installation

Users who need torch functionality can install it via:

pip install 'sagemaker-core[torch]'
# or
pip install 'sagemaker-core[all]'

Users who don't need torch get a leaner installation:

pip install sagemaker-core

Clear error messages are provided when attempting to use torch-dependent classes without torch installed.

Comments reviewed: 6
Files modified: sagemaker-core/pyproject.toml, sagemaker-core/src/sagemaker/core/deserializers/base.py, sagemaker-core/src/sagemaker/core/serializers/base.py, sagemaker-core/tests/unit/test_torch_optional_dependency.py

  • sagemaker-core/pyproject.toml: Change 'all' extras to reference 'torch' extra instead of duplicating the dependency string
  • sagemaker-core/src/sagemaker/core/deserializers/base.py: Preserve exception chain with 'from e' in TorchTensorDeserializer
  • sagemaker-core/src/sagemaker/core/serializers/base.py: Preserve exception chain with 'from e' in TorchTensorSerializer
  • sagemaker-core/tests/unit/test_torch_optional_dependency.py: Fix docstring line length, add pytest.importorskip guards, remove unused MagicMock import

@aviruthen aviruthen closed this Mar 31, 2026
@aviruthen aviruthen deleted the feature/torch-dependency-in-sagameker-core-to-be-made-5457 branch March 31, 2026 23:21
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants